[CK Tile] Unification work - mma transformations pipeline (#5508)

## Motivation

In this PR we showcase how the amdgcn structs could be used in a pipeline that does some extra pre/post processing.
For the sparse intrinsics, so far we compressed the A vector "on the fly" right before the execution of the builtin. This might introduce performance issues down the line if, for example, the user decided to chain multiple sparse builtins. We tackle this problem by creating a specific SparseCompressTransform.

A MmaPipelineBase is also created to facilitate those kind of higher level compositions of the amdgcn structs and is integrated to the existing WaveWiseMma prototype. There is an effort to facilitate future operations, like swizzle A/B, C transpose or double/quad attr num access through the MmaPipelineOptionFlags, but those are not yet defined and should do so in a future PR.
The pipeline base class is basically at the RFC stage.

We also create a runtime test for the existing WaveWiseMma, as well as one for the SparseMma pipeline.

## Technical Details

The goal should be to have the pipeline easily expandable. May the CRTP of the base class or the interface in general be insufficient or unable to handle all of our needs, then a design modification should be discussed.

## Test Plan

New tests are added.

## Test Result

Tests should pass.

---------

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
chris-tsiaousis-hpc
2026-04-14 09:25:01 +02:00
committed by GitHub
parent 5eee93e67c
commit 89c5e67028
20 changed files with 1580 additions and 585 deletions

View File

@@ -19,14 +19,16 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"

View File

@@ -204,6 +204,20 @@ struct Unsupported;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept HasExecSignature
* @brief Helper concept for exec signature check.
*/
template <typename MmaOp, typename... ExecArgs>
concept HasExecSignature = requires {
{
MmaOp::exec(typename MmaOp::AVecType{},
typename MmaOp::BVecType{},
typename MmaOp::CVecType{},
std::declval<ExecArgs>()...)
} -> std::convertible_to<typename MmaOp::CVecType>;
};
/**
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
@@ -213,7 +227,7 @@ template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;
{ MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>;
// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
@@ -230,13 +244,8 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
// Static exec function
{
MmaOp::exec(
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
} -> std::convertible_to<typename MmaOp::CVecType>;
};
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int>);
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -275,7 +284,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
// clang-format on
{
// This is a default pass-through implementation that doesn't do anything practical.
CK_TILE_DEVICE static CVecType const&
CK_TILE_DEVICE static auto
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
// Prints once across all thread blocks and threads.

View File

@@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx9_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx9;
};

View File

@@ -1,230 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma
{
using FragWiseMmaOp = MmaOp;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using ABufferType = AVecType[FragsM][FragsK];
using BBufferType = BVecType[FragsN][FragsK];
using CBufferType = CVecType[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
private:
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT const&>(inputBuffer);
}
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT&>(inputBuffer);
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input fragments,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input WaveTiles,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
// types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
public:
/*! @brief Forward to Mma operation with specified accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
return exec_row_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
return exec_col_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,299 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaPipelineOptionFlag
* @brief Individual option flags for configuring MmaPipeline behavior.
*/
enum struct MmaPipelineOptionFlag : unsigned
{
NONE = 0x0, ///< No flags set
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
};
/**
* @struct MmaPipelineOptionFlags
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
* and querying pipeline option flags.
*/
struct MmaPipelineOptionFlags
{
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
{
}
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
: mFlags(original.mFlags)
{
}
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
{
mFlags |= toType(addValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
{
MmaPipelineOptionFlags result(*this);
result |= addValue;
return result;
}
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
{
mFlags &= toType(maskValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
{
MmaPipelineOptionFlags result(*this);
result &= maskValue;
return result;
}
constexpr MmaPipelineOptionFlags operator~() const
{
MmaPipelineOptionFlags result(*this);
result.mFlags = ~result.mFlags;
return result;
}
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
{
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
}
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
private:
Type mFlags;
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
};
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
{
return rhs == lhs;
}
/**
* @class MmaPipelineBase
* @brief CRTP base class that implements the common Mma pipeline logic shared by
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
*
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
* pipeline behavior (e.g., C transposition, A compression).
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
* - Type aliases: @c InternalAVecT, @c InternalBVecT, @c InternalCVecT,
* @c CVecType, @c MmaOp
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform,
* @c DTransform
* - A static @c execImpl(std::tuple<A,B,C>&) method.
*
* @par The pipeline performs the following steps in @c exec():
* 1. Apply pre-transforms and format input buffers (A, B, C).
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
* 3. Apply post-transform and format the output buffer (D) back to the user type.
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
*/
// TODO: c++20: use MmaPipelineOptionFlags directly
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
struct MmaPipelineBase
{
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
private:
/**
* @brief Reconstruct a tuple with its first element passed through @c formatBuffer
* while preserving all remaining elements unchanged.
* @tparam DstT Target type for the formatted first element.
* @tparam SrcT Forwarding-reference type of the input tuple.
* @tparam Is Index pack for elements 1..N-1 of the tuple.
* @param inputTuple The source tuple whose first element will be formatted.
* @return A new tuple with the formatted first element and the remaining elements forwarded.
*/
template <typename DstT, typename SrcT, std::size_t... Is>
CK_TILE_DEVICE static auto formatBufferTupleImpl(SrcT&& inputTuple, std::index_sequence<Is...>)
{
auto&& first_elem = std::get<0>(std::forward<SrcT>(inputTuple));
using FirstElemResultType =
decltype(formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)));
using InputTupleType = ck_tile::remove_cvref_t<SrcT>;
return std::tuple<FirstElemResultType, std::tuple_element_t<Is + 1, InputTupleType>...>(
formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)),
std::get<Is + 1>(std::forward<SrcT>(inputTuple))...);
}
/**
* @brief Format (reinterpret-cast) a buffer to the hardware-native vector type @p DstT.
*
* Three cases are handled:
* - **Tuple**: recursively format the first element via @c formatBufferTupleImpl,
* preserving any metadata in the remaining tuple elements.
* - **Array / Pointer**: forwarded unchanged.
* - **Scalar / Vector**: reinterpret-cast to @p DstT (sizes must match).
*
* @tparam DstT The target hardware vector type.
* @tparam SrcT Forwarding-reference type of the input buffer.
* @param inputBuffer The buffer to format.
* @return A reference (or value) of type @p DstT corresponding to @p inputBuffer.
*/
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT&& inputBuffer)
{
using DecayedSrcT = ck_tile::remove_cvref_t<SrcT>;
// If SrcT is a tuple, extract the first element (the vector) and format it
// while preserving all remaining elements (metadata)
if constexpr(is_std_tuple_v<DecayedSrcT>)
{
// Create index sequence for all remaining elements (skip first)
constexpr std::size_t tuple_size = std::tuple_size_v<DecayedSrcT>;
return formatBufferTupleImpl<DstT>(std::forward<SrcT>(inputBuffer),
std::make_index_sequence<tuple_size - 1>{});
}
else if constexpr(std::is_array_v<DecayedSrcT> || std::is_pointer_v<DecayedSrcT>)
{
return std::forward<SrcT>(inputBuffer);
}
else
{
static_assert(sizeof(DstT) == sizeof(DecayedSrcT), "Size mismatch in formatBuffer");
using QualifiedDstT =
std::conditional_t<std::is_const_v<DecayedSrcT>, DstT const, DstT>;
return reinterpret_cast<QualifiedDstT&>(inputBuffer);
}
}
protected:
/** @brief Query whether a specific @ref MmaPipelineOptionFlag is set. */
template <MmaPipelineOptionFlag Flag>
constexpr CK_TILE_DEVICE static bool hasFlag()
{
return Flags.testFlag(Flag);
}
/**
* @brief Apply a transform **then** format the result to @p DstT.
* Used for input operands (A, B, C) before the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto preApplyTransform(Args&&... args)
{
return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...));
}
/**
* @brief Format a buffer to @p DstT **then** apply a transform.
* Used for the output operand (D) after the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto postApplyTransform(Args&&... args)
{
return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...));
}
/**
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
* @return A @c std::tuple of the transformed (A, B, C) vectors ready for the mma loop.
*/
template <typename ATransformInputs, typename BTransformInputs, typename CTransformInputs>
CK_TILE_DEVICE static decltype(auto)
applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum)
{
using InternalAVecT = typename Derived::InternalAVecT;
using InternalBVecT = typename Derived::InternalBVecT;
using InternalCVecT = typename Derived::InternalCVecT;
using ATransform = typename Derived::ATransform;
using BTransform = typename Derived::BTransform;
using CTransform = typename Derived::CTransform;
return std::make_tuple(
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)));
}
/**
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
* @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed.
* @return The final D output in the user-facing vector type.
*/
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static auto
applyTransformToOutput(std::tuple<ATransformResult, BTransformResult, CTransformResult>&& vecs)
{
auto&& [a_result, b_result, c_result] = vecs;
static_assert(!is_std_tuple_v<decltype(c_result)>,
"If CTransform returns more than the vector, update this function.");
using CVecT = typename Derived::CVecType;
using DTransform = typename Derived::DTransform;
return postApplyTransform<CVecT, DTransform>(c_result);
}
public:
/**
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
* @tparam VecTA Type of the A WaveTile buffer.
* @tparam VecTB Type of the B WaveTile buffer.
* @tparam VecTC Type of the C (accumulator) WaveTile buffer.
* @param a Input WaveTile A.
* @param b Input WaveTile B.
* @param accum Input/output accumulator WaveTile C.
* @return The output WaveTile D after accumulation and post-transform.
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
Derived::execImpl(transformed_inputs);
return applyTransformToOutput(std::move(transformed_inputs));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
}
}
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept MmaPipelineI
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
*/
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
// Include the implementations
#include "wmma/wmma_selector.hpp"
#include "mfma/mfma_selector.hpp"
#include "sparse/sparse_selector.hpp"

View File

@@ -18,6 +18,18 @@ struct PassThroughTransform
}
};
/**
* @struct MmaDefaultPassThroughTransforms
* @brief Implements the default MMA transforms
*/
struct MmaDefaultPassThroughTransforms
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};
/**
* @class MmaTransformsDefaultSelector
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
@@ -27,7 +39,10 @@ struct PassThroughTransform
*/
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
struct MmaTransformsDefaultSelector
{
using SelectedTransforms = MmaDefaultPassThroughTransforms;
};
#if CK_TILE_CONCEPTS

View File

@@ -0,0 +1,177 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_pipeline.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
#include <tuple>
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
namespace dense::wavewise::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool SwapAB>
constexpr inline int getPipelineFlags()
{
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
}
} // namespace dense::wavewise::detail
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam SwapAB Swaps A and B input vectors
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool SwapAB = false,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using AVecType = InternalAVecT[FragsM][FragsK];
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static void execImpl(std::tuple<VecTA, VecTB, VecTC>& vecs)
{
auto& [a_frag, b_frag, c_frag] = vecs;
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major
// order. We also have to ensure that the incoming vector WaveTiles are converted to
// native vector types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the blocks in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else
{
static_assert(false);
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
constexpr inline int getPipelineFlags()
{
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A);
}
} // namespace sparse::detail
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
"Cannot transpose C in sparse intrinsics.");
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Calculate the uncompressed A vector type
struct ExternalAVecCalculator
{
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
};
// Expose caller-side vector types
using AVecType = typename ExternalAVecCalculator::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static void
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
{
checkATransformResult<ATransformResult>();
auto& [a_result, b_vec, c_vec] = vecs;
auto& [a_vec, idx] = a_result;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
}
private:
// Type check helper - not a device function, so std::declval is available
template <typename ATransformResult>
static constexpr void checkATransformResult()
{
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
static_assert(std::is_same_v<ATransformResult,
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -6,22 +6,101 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
/**
* @struct MmaDefaultTransformsSparse
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 4, 1>{}([&](auto j) {
if(static_cast<float>(a_vec[i * 4 + j]) != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
} // namespace sparse::detail
/**
* @class SparseCompressTransform
* @brief Performs 2:4 structured sparsity compression to the vector v and produces an index mask.
* @note Returns a tuple of two. The first element is the vector v with the same scalar type but
* its size halved. The second element is the index mask.
*/
template <index_t CompressionRatio>
struct SparseCompressTransform
{
template <typename VecType>
CK_TILE_DEVICE static decltype(auto) exec(VecType& v)
{
using VecTraits = vector_traits<remove_cvref_t<VecType>>;
using ScalarT = typename VecTraits::scalar_type;
static constexpr auto VecN = VecTraits::vector_size;
static constexpr index_t CompressedSize = VecN / CompressionRatio;
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
// TODO c++20: Use bit_cast
return std::tuple<VecCompressed&, int32_t>(
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
}
};
/**
* @class MmaDefaultTransformsSparse
* @brief Implements the default transforms for Sparse
*
* For 2:4 structured sparsity with inline register metadata:
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
* - ATransform: 2:4 structured sparsity compression
* - BTransform: Pass-through (sparse operands already formatted)
* - CTransform: Pass-through (input accumulator)
* - DTransform: Pass-through (output accumulator as-is)
*/
template <index_t CompressionRatio>
struct MmaDefaultTransformsSparse
{
using ATransform = PassThroughTransform;
using ATransform = SparseCompressTransform<CompressionRatio>;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
@@ -42,7 +121,7 @@ struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
{
using SelectedTransforms = MmaDefaultTransformsSparse;
using SelectedTransforms = MmaDefaultTransformsSparse<MmaOp::kCompressionRatio>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx11_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx11_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx11;
};
@@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector<MmaOp,
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx12_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx12;
};

View File

@@ -209,4 +209,21 @@ template <typename ADataType, typename BDataType>
using largest_type_t =
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
/**
* @brief Type trait to detect whether a type is a @c std::tuple specialization.
* @tparam T The type to inspect.
*/
template <typename T>
struct is_std_tuple : std::false_type
{
};
template <typename... Args>
struct is_std_tuple<std::tuple<Args...>> : std::true_type
{
};
template <typename T>
static constexpr bool is_std_tuple_v = is_std_tuple<T>::value;
} // namespace ck_tile

View File

@@ -5,52 +5,9 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
@@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl
return WarpGemmAttribute_::get_num_of_access();
}
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
//----------------------------------------------------------------------------------------------
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
/// elements into lower part of a_vec to half its effective size.
///
/// @param a_vec Vector to be compressed.
///
/// @return Four 2-bit indexes of non-zero elements locations
///
template <typename AVec>
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
{
return compress_a_impl<ADataType, CompressedSize>(a_vec);
int32_t idx = 0b11101110;
static_for<0, 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
@@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
const int32_t idx = compress_a(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};

View File

@@ -7,14 +7,15 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# TODO: This test is temporarily disabled for cooperation / work planning reasons. Re-enable after merging related work.
# if(GPU_TARGETS MATCHES "gfx9|gfx12")
# add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
# target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# endif()
if(GPU_TARGETS MATCHES "gfx9|gfx12")
add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp)
target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
@@ -44,3 +45,6 @@ if(GPU_TARGETS MATCHES "gfx12")
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,123 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <functional>
#include <vector>
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "../get_wave_size_helper.hpp"
template <typename AType_ = ck_tile::fp16_t,
typename BType_ = ck_tile::fp16_t,
typename CType_ = ck_tile::fp32_t,
uint32_t WaveTileM_ = 16,
uint32_t WaveTileN_ = 16,
uint32_t WaveTileK_ = 32>
struct MmaPipelineTest
{
using AType = AType_;
using BType = BType_;
using CType = CType_;
static constexpr auto WaveTileM = WaveTileM_;
static constexpr auto WaveTileN = WaveTileN_;
static constexpr auto WaveTileK = WaveTileK_;
void test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize;
uint32_t BElements = FragN * FragK / deviceWarpSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
}
else
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1));
}
std::vector<BType> h_b(BElements, type_convert<BType>(1));
std::vector<CType> h_c(CElements, type_convert<CType>(0));
std::vector<CType> h_out(CElements, type_convert<CType>(0));
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
{
EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3);
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
}
};

View File

@@ -0,0 +1,66 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdint>
#include <gtest/gtest.h>
#include <iostream>
#include <numeric>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
namespace {
using namespace ck_tile::core::arch::mma;
}
TEST(MmaPipelineOptionFlagsTests, ConversionTests)
{
MmaPipelineOptionFlags flags_0{};
MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap};
MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A};
MmaPipelineOptionFlags flags_3{0b11};
EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE));
}
TEST(MmaPipelineOptionFlagsTests, OperatorsTests)
{
MmaPipelineOptionFlags flags{};
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE));
flags |= MmaPipelineOptionFlag::ABSwap;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
flags |= MmaPipelineOptionFlag::COMPRESS_A;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
flags &= MmaPipelineOptionFlag::COMPRESS_A;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A));
}

View File

@@ -0,0 +1,523 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdint>
#include <gtest/gtest.h>
#include <iostream>
#include <numeric>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "pipeline_tests_helper.hpp"
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
{
// Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32)
using TestSparseMfma16x16 = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
"GFX950 sparse 16x16x32 should have SparseMFMAOp type");
static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
"GFX950 sparse 16x16x32 should be detected as Sparse");
std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
}
TEST(SparseMMATrait, MmaOpTraitsIntegration)
{
// Create a sparse MMA op (16x16x32 fp16 specialization)
using TestSparseMmma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
// Get its traits
using TestTraits = MmaOpTraits<TestSparseMmma>;
// Verify trait detection
static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
}
TEST(SparseMMATrait, TestConceptRequirements)
{
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
using TestSparseMmma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(MmaOpI<TestSparseMma>);
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
}
TEST(SparseMMATrait, DenseVsSparseDistinction)
{
// Dense MFMA from mfma/mfma_gfx9.hpp
using DenseMfma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::DENSE>;
// Sparse MFMA on GFX950
using SparseMfma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
// Verify they share the same OpType tag but belong to different op families
static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily,
"Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
// Verify traits correctly identify them
static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported,
"Dense MFMA should be identified correctly");
static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported,
"Sparse MFMA should be identified correctly");
std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
}
TEST(SparseMMATrait, SparseSelector)
{
static_for<1, 33, 1>{}([](auto i) {
using Selected = typename MmaDefaultSelector<fp16_t,
fp16_t,
fp32_t,
static_cast<uint32_t>(i),
static_cast<uint32_t>(i),
static_cast<uint32_t>(2 * i),
CompilerTargetGfx950,
MmaOpFamily::SPARSE>::SelectedOp;
static constexpr bool isValid = (i == 16) || (i == 32);
if constexpr(isValid)
{
// Selector should pick a sparse MFMA implementation
static_assert(MmaOpTraits<Selected>::IsSparse);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert(std::is_same_v<typename Selected::OpType, MfmaOp>);
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
}
});
}
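// Note: the selected op's fragment depth (kK) is what drives the K-loop trip
// count in test_sparse_accum_over_k below, so a selector regression for the
// valid sizes would also surface in the runtime test.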
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
using Pipeline = SparseMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
static constexpr uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = Pipeline::exec(
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
}
*reinterpret_cast<CVecType*>(out) = result;
}
// Live test on real hardware for sparse selection and execution.
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
{
MmaPipelineTest<> test;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma = (currentArchId >= amdgcn_target_id::GFX942) &&
(currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
// Under 2:4 sparsity only half of the A values are non-zero, hence K/2 rather than K.
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK) / 2;
};
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_sparse_accum_over_k<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK><<<1, waveSize>>>(a, b, c, out);
};
// Initialize A with 2:4 structured sparsity pattern: {1, 0, 1, 0, ...}
// This ensures the sparse compression transform is actually exercised —
// a no-op or broken compression would pass zeros through, causing incorrect results.
const std::function<fp16_t(size_t)> sparseAInit = [](size_t i) -> fp16_t {
return (i % 2 == 0) ? type_convert<fp16_t>(1) : type_convert<fp16_t>(0);
};
test.test_pipeline(should_skip, kernel, validator, sparseAInit);
}
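// The 2:4 compression contract exercised below: each group of 4 input elements
// is reduced to (up to) 2 non-zeros, and every compressed element carries a
// 2-bit field recording the slot (0-3) it was taken from.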
template <uint32_t CompressionRatio, typename Vec>
__global__ void test_sparse_transform(void* a, void* idx)
{
using ResultT =
decltype(SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a)));
using FirstT = std::tuple_element_t<0, ResultT>;
const auto& [vec, i] = SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a));
*reinterpret_cast<remove_cvref_t<FirstT>*>(a) = vec;
*reinterpret_cast<int32_t*>(idx) = i;
}
// Generalized helper: runs the sparse transform kernel and verifies compressed output and index.
template <int NUM, int RATIO, typename Type>
void sparse_transform_verify(const std::vector<Type>& input,
const std::vector<Type>& expected_output,
int32_t expected_idx)
{
static_assert(RATIO == 2, "Extend functionality if other ratio is used.");
ASSERT_EQ(static_cast<int>(input.size()), NUM);
ASSERT_EQ(static_cast<int>(expected_output.size()), NUM / RATIO);
int devCount;
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
// Query the device count first so machines without a HIP device skip the test
// instead of tripping the error check on hipGetDevice.
if(devCount == 0)
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
}
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
// TODO: c++20 add check for arch id
if(currentArchId == amdgcn_target_id::HOST)
{
GTEST_SKIP() << "Host-only compilation target. Skipping test.";
}
Type* d_v;
int32_t* d_idx;
static constexpr auto Size = sizeof(Type) * NUM;
HIP_CHECK_ERROR(hipMalloc(&d_v, Size));
HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(int32_t)));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_v, input.data(), Size, hipMemcpyHostToDevice));
test_sparse_transform<RATIO, ext_vector_t<Type, NUM>><<<1, 32>>>(d_v, d_idx);
HIP_CHECK_ERROR(hipDeviceSynchronize());
std::vector<Type> h_out(NUM / RATIO, static_cast<Type>(0));
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_v, Size / RATIO, hipMemcpyDeviceToHost));
int32_t h_idx;
HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(int32_t), hipMemcpyDeviceToHost));
EXPECT_EQ(h_idx, expected_idx) << "Index mask mismatch";
for(int i = 0; i < NUM / RATIO; ++i)
{
EXPECT_EQ(h_out[i], expected_output[i]) << "Output mismatch at position " << i;
}
// Semantic index validation: each 2-bit field in h_idx encodes the original
// slot (0-3) within the group of 4 that the corresponding compressed element
// came from. Verify that the index is consistent with input and output.
//
// Note: when a group has fewer than 2 non-zeros, unused output slots contain
// initialization values (from nonzero_elems init) that don't correspond to the
// default index (slot 2). We only validate entries where the index was explicitly
// set, i.e. where input[slot] is non-zero.
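//
// Worked example (cf. the slots-1-and-3 case below): an input group {0, a, 0, b}
// compresses to {a, b} with index fields 0b01 and 0b11, i.e. the nibble 0b1101.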
constexpr int CompressedSize = NUM / RATIO;
for(int i = 0; i < CompressedSize; ++i)
{
int slot = (h_idx >> (2 * i)) & 0b11;
int group = i / 2;
Type input_at_slot = input[group * 4 + slot];
// Only check when input at the indexed slot is non-zero (explicitly assigned)
// or when both are zero (consistent default for all-zero groups).
if(static_cast<float>(input_at_slot) != 0.0f || static_cast<float>(h_out[i]) == 0.0f)
{
EXPECT_EQ(h_out[i], input_at_slot)
<< "Index field " << i << " points to slot " << slot << " in group " << group
<< " but output[" << i << "] != input[" << (group * 4 + slot) << "]";
}
}
HIP_CHECK_ERROR(hipFree(d_v));
HIP_CHECK_ERROR(hipFree(d_idx));
}
// Helper: build expected index from a per-group 4-bit pattern, repeated for all groups.
// Each group of 4 input elements contributes 2 compressed elements, i.e.
// 2 x 2-bit index fields = 4 bits per group.
static int32_t build_repeated_group_idx(int num_groups, int32_t group_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= (group_bits_4 << (4 * g));
return idx;
}
// Helper: build expected index from alternating even/odd 4-bit group patterns.
static int32_t build_alternating_group_idx(int num_groups, int32_t even_bits_4, int32_t odd_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << (4 * g));
return idx;
}
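// Complementary illustrative helper (an assumption that simply mirrors the
// field layout the builders above encode): extract the 2-bit slot field of
// compressed element i, as sparse_transform_verify does inline.
[[maybe_unused]] static int decode_idx_field(int32_t idx, int i)
{
return (idx >> (2 * i)) & 0b11;
}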
// 1. Basic correctness: valid divisible sizes
// Input pattern: {1, 0, 3, 0, 5, 0, 7, 0, ...} → non-zeros at slots 0,2
// Group idx pattern: field0=0b00 (slot 0), field1=0b10 (slot 2) → 0b1000
template <int NUM, int RATIO, typename Type>
void sparse_transform_test_case()
{
std::vector<Type> v(NUM);
for(int i = 0; i < NUM; ++i)
{
v[i] = (i % 2 == 0) ? static_cast<Type>(i + 1) : static_cast<Type>(0);
}
std::vector<Type> expected_out(NUM / RATIO);
for(int i = 0; i < NUM / RATIO; ++i)
{
expected_out[i] = v[i * 2];
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1000);
sparse_transform_verify<NUM, RATIO, Type>(v, expected_out, expected_idx);
}
TEST(SparseTransformsTest, ValidCompressionRatio)
{
// TODO: extend those when new sparse builtins are
// introduced and use different type combinations
sparse_transform_test_case<8, 2, fp16_t>();
sparse_transform_test_case<16, 2, fp16_t>();
sparse_transform_test_case<32, 2, fp16_t>();
}
// All-zero input: no non-zeros in any group of 4.
// Each output pair defaults to {a_vec[slot2], a_vec[slot3]} = {0, 0},
// and the index uses default slot-2 encoding (0b10) for every 2-bit field.
// Group idx pattern: 0b1010
template <int NUM>
void sparse_transform_all_zero()
{
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2, static_cast<T>(0));
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1010);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
TEST(SparseTransformsTest, AllZeroInput)
{
sparse_transform_all_zero<8>();
sparse_transform_all_zero<16>();
sparse_transform_all_zero<32>();
}
// Single non-zero per group of 4 (at slot 3).
// nonzero_elems initializes to {a_vec[slot2]=0, a_vec[slot3]=V}.
// Only j=3 triggers: nonzero_elems[0]=V, field0=0b11, pos becomes 1.
// nonzero_elems[1] keeps its init V. Output: {V, V}.
// Group idx pattern: field0=0b11, field1=0b10 (default) → 0b1011
template <int NUM>
void sparse_transform_single_nonzero()
{
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2);
for(int g = 0; g < NUM / 4; ++g)
{
T val = static_cast<T>(g + 5);
input[g * 4 + 3] = val;
expected_output[g * 2] = val;
expected_output[g * 2 + 1] = val;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1011);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
TEST(SparseTransformsTest, SingleNonZeroPerGroup)
{
sparse_transform_single_nonzero<8>();
sparse_transform_single_nonzero<16>();
sparse_transform_single_nonzero<32>();
}
// Non-zeros at slots 1 and 3 in each group.
// Input: {0, a, 0, b, ...}. Output: {a, b, ...}.
// Group idx pattern: field0=0b01 (slot 1), field1=0b11 (slot 3) → 0b1101
template <int NUM>
void sparse_transform_slots_1_and_3()
{
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2);
for(int g = 0; g < NUM / 4; ++g)
{
T a = static_cast<T>(g * 2 + 3);
T b = static_cast<T>(g * 2 + 4);
input[g * 4 + 1] = a;
input[g * 4 + 3] = b;
expected_output[g * 2] = a;
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
TEST(SparseTransformsTest, NonZerosAtSlots1And3)
{
sparse_transform_slots_1_and_3<8>();
sparse_transform_slots_1_and_3<16>();
sparse_transform_slots_1_and_3<32>();
}
// Non-zeros at slots 0 and 3 in each group (non-adjacent).
// Input: {a, 0, 0, b, ...}. Output: {a, b, ...}.
// Group idx pattern: field0=0b00 (slot 0), field1=0b11 (slot 3) → 0b1100
template <int NUM>
void sparse_transform_slots_0_and_3()
{
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2);
for(int g = 0; g < NUM / 4; ++g)
{
T a = static_cast<T>(g * 2 + 2);
T b = static_cast<T>(g * 2 + 3);
input[g * 4] = a;
input[g * 4 + 3] = b;
expected_output[g * 2] = a;
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1100);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
TEST(SparseTransformsTest, NonZerosAtSlots0And3)
{
sparse_transform_slots_0_and_3<8>();
sparse_transform_slots_0_and_3<16>();
sparse_transform_slots_0_and_3<32>();
}
// Mixed sparsity pattern: even groups have non-zeros at slots 0,2; odd groups at slots 1,3.
// Even group idx: field0=0b00, field1=0b10 → 0b1000
// Odd group idx: field0=0b01, field1=0b11 → 0b1101
template <int NUM>
void sparse_transform_mixed()
{
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2);
for(int g = 0; g < NUM / 4; ++g)
{
T a = static_cast<T>(g * 2 + 1);
T b = static_cast<T>(g * 2 + 2);
if(g % 2 == 0)
{
// Slots 0, 2
input[g * 4] = a;
input[g * 4 + 2] = b;
}
else
{
// Slots 1, 3
input[g * 4 + 1] = a;
input[g * 4 + 3] = b;
}
expected_output[g * 2] = a;
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_alternating_group_idx(NUM / 4, 0b1000, 0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
TEST(SparseTransformsTest, MixedSparsityPattern)
{
sparse_transform_mixed<8>();
sparse_transform_mixed<16>();
sparse_transform_mixed<32>();
}

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "pipeline_tests_helper.hpp"
#include <memory>
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
bool CTranspose>
__global__ void test_wavewise_pipeline(void* a, void* b, void* c, void* out)
{
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
MmaAccumPolicy::ROW_MAJOR,
CTranspose,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
auto result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
*reinterpret_cast<CVecType*>(c));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
// Write the result back through its address; skip this for the Unsupported
// fallback, whose CVecType is only a pass-through placeholder.
__builtin_memcpy(out, static_cast<const void*>(&result), sizeof(CVecType));
}
}
namespace {
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = false; // WMMA targets are not exercised by this test
bool isSupportedMfma =
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
// With A and B all 1's, the dense accumulation over K yields exactly K.
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK);
};
} // namespace
TEST(WaveWiseMmaPipeline, testKIter)
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
false><<<1, waveSize>>>(a, b, c, out);
};
test.test_pipeline(should_skip, kernel, validator);
}
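// Swapping the A and B operands is the usual way to realize a transposed C
// ((A*B)^T = B^T * A^T), which is presumably why this variant instantiates the
// pipeline with CTranspose = true.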
TEST(WaveWiseMmaPipeline, testKIterSwapAB)
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
true><<<1, waveSize>>>(a, b, c, out);
};
test.test_pipeline(should_skip, kernel, validator);
}

View File

@@ -7,7 +7,7 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/hip_check_error.hpp"

View File

@@ -1,271 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <iostream>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "get_wave_size_helper.hpp"
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
{
// Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32)
using TestSparseMfma16x16 = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
"GFX950 sparse 16x16x32 should have SparseMFMAOp type");
static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
"GFX950 sparse 16x16x32 should be detected as Sparse");
std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
}
TEST(SparseMMATrait, MmaOpTraitsIntegration)
{
// Create a sparse MMA op (16x16x32 fp16 specialization)
using TestSparseMmma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
// Get its traits
using TestTraits = MmaOpTraits<TestSparseMmma>;
// Verify trait detection
static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
}
TEST(SparseMMATrait, DenseVsSparseDistinction)
{
// Dense MFMA from mfma/mfma_gfx9.hpp
using DenseMfma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::DENSE>;
// Sparse MFMA on GFX950
using SparseMfma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
// Verify they have different operation types
static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily,
"Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
// Verify traits correctly identify them
static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported,
"Dense MFMA should be identified correctly");
static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported,
"Sparse MFMA should be identified correctly");
std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
}
TEST(SparseMMATrait, SparseSelector)
{
static_for<1, 33, 1>{}([](auto i) {
using Selected = typename MmaDefaultSelector<fp16_t,
fp16_t,
fp32_t,
static_cast<uint32_t>(i),
static_cast<uint32_t>(i),
static_cast<uint32_t>(2 * i),
CompilerTargetGfx950,
MmaOpFamily::SPARSE>::SelectedOp;
static constexpr bool isValid = (i == 16) || (i == 32);
if constexpr(isValid)
{
// Selector should pick a sparse MFMA implementation
static_assert(MmaOpTraits<Selected>::IsSparse);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
}
});
}
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
using CompilerTarget = decltype(get_compiler_target());
using Selector = MmaDefaultSelector<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SPARSE>;
using MmaOp = typename Selector::SelectedOp;
using CVecType = typename MmaOp::CVecType;
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
*reinterpret_cast<typename MmaOp::BVecType*>(b),
result);
}
*reinterpret_cast<typename MmaOp::CVecType*>(out) = result;
}
// Live test on real hardware for sparse selection and execution.
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
{
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma =
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
// TODO: c++20 add check for arch id
if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) ||
!(isSupportedWmma || isSupportedMfma))
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
}
using AType = fp16_t;
using BType = fp16_t;
using CType = fp32_t;
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t WaveTileM = 16;
static constexpr uint32_t WaveTileN = 16;
static constexpr uint32_t WaveTileK = 32;
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize;
uint32_t BElements = FragN * FragK / deviceWarpSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
// Initialize A and B to all 1's, C to all 0's
std::vector<AType> h_a(AElements, static_cast<AType>(1));
std::vector<BType> h_b(BElements, static_cast<BType>(1));
std::vector<CType> h_c(CElements, static_cast<CType>(0));
std::vector<CType> h_out(CElements, static_cast<CType>(0));
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
test_sparse_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Output should be FragK for all elements, because the inputs are all 1's
for(size_t i = 0; i < CElements; ++i)
{
// In sparse only half of the A values are non-zero, thus the /2.
CType expected = static_cast<CType>(FragK) / 2;
EXPECT_NEAR(h_out[i], expected, 1e-3);
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
}