[CK Tile] Unification work - mma transformations pipeline (#5508)

## Motivation

In this PR we showcase how the amdgcn structs could be used in a pipeline that does some extra pre/post processing.
For the sparse intrinsics, so far we compressed the A vector "on the fly" right before the execution of the builtin. This might introduce performance issues down the line if, for example, the user decided to chain multiple sparse builtins. We tackle this problem by creating a specific SparseCompressTransform.

A MmaPipelineBase is also created to facilitate those kind of higher level compositions of the amdgcn structs and is integrated to the existing WaveWiseMma prototype. There is an effort to facilitate future operations, like swizzle A/B, C transpose or double/quad attr num access through the MmaPipelineOptionFlags, but those are not yet defined and should do so in a future PR.
The pipeline base class is basically at the RFC stage.

We also create a runtime test for the existing WaveWiseMma, as well as one for the SparseMma pipeline.

## Technical Details

The goal should be to have the pipeline easily expandable. May the CRTP of the base class or the interface in general be insufficient or unable to handle all of our needs, then a design modification should be discussed.

## Test Plan

New tests are added.

## Test Result

Tests should pass.

---------

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
chris-tsiaousis-hpc
2026-04-14 09:25:01 +02:00
committed by GitHub
parent 5eee93e67c
commit 89c5e67028
20 changed files with 1580 additions and 585 deletions

View File

@@ -19,14 +19,16 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"

View File

@@ -204,6 +204,20 @@ struct Unsupported;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept HasExecSignature
* @brief Helper concept for exec signature check.
*/
template <typename MmaOp, typename... ExecArgs>
concept HasExecSignature = requires {
{
MmaOp::exec(typename MmaOp::AVecType{},
typename MmaOp::BVecType{},
typename MmaOp::CVecType{},
std::declval<ExecArgs>()...)
} -> std::convertible_to<typename MmaOp::CVecType>;
};
/**
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
@@ -213,7 +227,7 @@ template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;
{ MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>;
// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
@@ -230,13 +244,8 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
// Static exec function
{
MmaOp::exec(
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
} -> std::convertible_to<typename MmaOp::CVecType>;
};
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int>);
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -275,7 +284,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
// clang-format on
{
// This is a default pass-through implementation that doesn't do anything practical.
CK_TILE_DEVICE static CVecType const&
CK_TILE_DEVICE static auto
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
// Prints once across all thread blocks and threads.

View File

@@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx9_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx9;
};

View File

@@ -1,230 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma
{
using FragWiseMmaOp = MmaOp;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using ABufferType = AVecType[FragsM][FragsK];
using BBufferType = BVecType[FragsN][FragsK];
using CBufferType = CVecType[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
private:
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT const&>(inputBuffer);
}
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT&>(inputBuffer);
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input fragments,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input WaveTiles,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
// types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
public:
/*! @brief Forward to Mma operation with specified accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
return exec_row_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
return exec_col_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,299 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaPipelineOptionFlag
* @brief Individual option flags for configuring MmaPipeline behavior.
*/
enum struct MmaPipelineOptionFlag : unsigned
{
NONE = 0x0, ///< No flags set
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
};
/**
* @struct MmaPipelineOptionFlags
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
* and querying pipeline option flags.
*/
struct MmaPipelineOptionFlags
{
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
{
}
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
: mFlags(original.mFlags)
{
}
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
{
mFlags |= toType(addValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
{
MmaPipelineOptionFlags result(*this);
result |= addValue;
return result;
}
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
{
mFlags &= toType(maskValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
{
MmaPipelineOptionFlags result(*this);
result &= maskValue;
return result;
}
constexpr MmaPipelineOptionFlags operator~() const
{
MmaPipelineOptionFlags result(*this);
result.mFlags = ~result.mFlags;
return result;
}
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
{
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
}
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
private:
Type mFlags;
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
};
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
{
return rhs == lhs;
}
/**
* @class MmaPipelineBase
* @brief CRTP base class that implements the common Mma pipeline logic shared by
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
*
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
* pipeline behavior (e.g., C transposition, A compression).
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
* - Type aliases: @c InternalAVecT, @c InternalBVecT, @c InternalCVecT,
* @c CVecType, @c MmaOp
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform,
* @c DTransform
* - A static @c execImpl(std::tuple<A,B,C>&) method.
*
* @par The pipeline performs the following steps in @c exec():
* 1. Apply pre-transforms and format input buffers (A, B, C).
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
* 3. Apply post-transform and format the output buffer (D) back to the user type.
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
*/
// TODO: c++20: use MmaPipelineOptionFlags directly
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
struct MmaPipelineBase
{
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
private:
/**
* @brief Reconstruct a tuple with its first element passed through @c formatBuffer
* while preserving all remaining elements unchanged.
* @tparam DstT Target type for the formatted first element.
* @tparam SrcT Forwarding-reference type of the input tuple.
* @tparam Is Index pack for elements 1..N-1 of the tuple.
* @param inputTuple The source tuple whose first element will be formatted.
* @return A new tuple with the formatted first element and the remaining elements forwarded.
*/
template <typename DstT, typename SrcT, std::size_t... Is>
CK_TILE_DEVICE static auto formatBufferTupleImpl(SrcT&& inputTuple, std::index_sequence<Is...>)
{
auto&& first_elem = std::get<0>(std::forward<SrcT>(inputTuple));
using FirstElemResultType =
decltype(formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)));
using InputTupleType = ck_tile::remove_cvref_t<SrcT>;
return std::tuple<FirstElemResultType, std::tuple_element_t<Is + 1, InputTupleType>...>(
formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)),
std::get<Is + 1>(std::forward<SrcT>(inputTuple))...);
}
/**
* @brief Format (reinterpret-cast) a buffer to the hardware-native vector type @p DstT.
*
* Three cases are handled:
* - **Tuple**: recursively format the first element via @c formatBufferTupleImpl,
* preserving any metadata in the remaining tuple elements.
* - **Array / Pointer**: forwarded unchanged.
* - **Scalar / Vector**: reinterpret-cast to @p DstT (sizes must match).
*
* @tparam DstT The target hardware vector type.
* @tparam SrcT Forwarding-reference type of the input buffer.
* @param inputBuffer The buffer to format.
* @return A reference (or value) of type @p DstT corresponding to @p inputBuffer.
*/
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT&& inputBuffer)
{
using DecayedSrcT = ck_tile::remove_cvref_t<SrcT>;
// If SrcT is a tuple, extract the first element (the vector) and format it
// while preserving all remaining elements (metadata)
if constexpr(is_std_tuple_v<DecayedSrcT>)
{
// Create index sequence for all remaining elements (skip first)
constexpr std::size_t tuple_size = std::tuple_size_v<DecayedSrcT>;
return formatBufferTupleImpl<DstT>(std::forward<SrcT>(inputBuffer),
std::make_index_sequence<tuple_size - 1>{});
}
else if constexpr(std::is_array_v<DecayedSrcT> || std::is_pointer_v<DecayedSrcT>)
{
return std::forward<SrcT>(inputBuffer);
}
else
{
static_assert(sizeof(DstT) == sizeof(DecayedSrcT), "Size mismatch in formatBuffer");
using QualifiedDstT =
std::conditional_t<std::is_const_v<DecayedSrcT>, DstT const, DstT>;
return reinterpret_cast<QualifiedDstT&>(inputBuffer);
}
}
protected:
/** @brief Query whether a specific @ref MmaPipelineOptionFlag is set. */
template <MmaPipelineOptionFlag Flag>
constexpr CK_TILE_DEVICE static bool hasFlag()
{
return Flags.testFlag(Flag);
}
/**
* @brief Apply a transform **then** format the result to @p DstT.
* Used for input operands (A, B, C) before the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto preApplyTransform(Args&&... args)
{
return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...));
}
/**
* @brief Format a buffer to @p DstT **then** apply a transform.
* Used for the output operand (D) after the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto postApplyTransform(Args&&... args)
{
return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...));
}
/**
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
* @return A @c std::tuple of the transformed (A, B, C) vectors ready for the mma loop.
*/
template <typename ATransformInputs, typename BTransformInputs, typename CTransformInputs>
CK_TILE_DEVICE static decltype(auto)
applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum)
{
using InternalAVecT = typename Derived::InternalAVecT;
using InternalBVecT = typename Derived::InternalBVecT;
using InternalCVecT = typename Derived::InternalCVecT;
using ATransform = typename Derived::ATransform;
using BTransform = typename Derived::BTransform;
using CTransform = typename Derived::CTransform;
return std::make_tuple(
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)));
}
/**
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
* @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed.
* @return The final D output in the user-facing vector type.
*/
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static auto
applyTransformToOutput(std::tuple<ATransformResult, BTransformResult, CTransformResult>&& vecs)
{
auto&& [a_result, b_result, c_result] = vecs;
static_assert(!is_std_tuple_v<decltype(c_result)>,
"If CTransform returns more than the vector, update this function.");
using CVecT = typename Derived::CVecType;
using DTransform = typename Derived::DTransform;
return postApplyTransform<CVecT, DTransform>(c_result);
}
public:
/**
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
* @tparam VecTA Type of the A WaveTile buffer.
* @tparam VecTB Type of the B WaveTile buffer.
* @tparam VecTC Type of the C (accumulator) WaveTile buffer.
* @param a Input WaveTile A.
* @param b Input WaveTile B.
* @param accum Input/output accumulator WaveTile C.
* @return The output WaveTile D after accumulation and post-transform.
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
Derived::execImpl(transformed_inputs);
return applyTransformToOutput(std::move(transformed_inputs));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
}
}
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept MmaPipelineI
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
*/
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
// Include the implementations
#include "wmma/wmma_selector.hpp"
#include "mfma/mfma_selector.hpp"
#include "sparse/sparse_selector.hpp"

View File

@@ -18,6 +18,18 @@ struct PassThroughTransform
}
};
/**
* @struct MmaDefaultPassThroughTransforms
* @brief Implements the default MMA transforms
*/
struct MmaDefaultPassThroughTransforms
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};
/**
* @class MmaTransformsDefaultSelector
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
@@ -27,7 +39,10 @@ struct PassThroughTransform
*/
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
struct MmaTransformsDefaultSelector
{
using SelectedTransforms = MmaDefaultPassThroughTransforms;
};
#if CK_TILE_CONCEPTS

View File

@@ -0,0 +1,177 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_pipeline.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
#include <tuple>
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
namespace dense::wavewise::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool SwapAB>
constexpr inline int getPipelineFlags()
{
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
}
} // namespace dense::wavewise::detail
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam SwapAB Swaps A and B input vectors
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool SwapAB = false,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using AVecType = InternalAVecT[FragsM][FragsK];
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static void execImpl(std::tuple<VecTA, VecTB, VecTC>& vecs)
{
auto& [a_frag, b_frag, c_frag] = vecs;
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major
// order. We also have to ensure that the incoming vector WaveTiles are converted to
// native vector types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the blocks in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else
{
static_assert(false);
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
constexpr inline int getPipelineFlags()
{
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A);
}
} // namespace sparse::detail
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
"Cannot transpose C in sparse intrinsics.");
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Calculate the uncompressed A vector type
struct ExternalAVecCalculator
{
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
};
// Expose caller-side vector types
using AVecType = typename ExternalAVecCalculator::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static void
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
{
checkATransformResult<ATransformResult>();
auto& [a_result, b_vec, c_vec] = vecs;
auto& [a_vec, idx] = a_result;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
}
private:
// Type check helper - not a device function, so std::declval is available
template <typename ATransformResult>
static constexpr void checkATransformResult()
{
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
static_assert(std::is_same_v<ATransformResult,
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -6,22 +6,101 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
/**
* @struct MmaDefaultTransformsSparse
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 4, 1>{}([&](auto j) {
if(static_cast<float>(a_vec[i * 4 + j]) != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
} // namespace sparse::detail
/**
* @class SparseCompressTransform
* @brief Performs 2:4 structured sparsity compression to the vector v and produces an index mask.
* @note Returns a tuple of two. The first element is the vector v with the same scalar type but
* its size halved. The second element is the index mask.
*/
template <index_t CompressionRatio>
struct SparseCompressTransform
{
template <typename VecType>
CK_TILE_DEVICE static decltype(auto) exec(VecType& v)
{
using VecTraits = vector_traits<remove_cvref_t<VecType>>;
using ScalarT = typename VecTraits::scalar_type;
static constexpr auto VecN = VecTraits::vector_size;
static constexpr index_t CompressedSize = VecN / CompressionRatio;
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
// TODO c++20: Use bit_cast
return std::tuple<VecCompressed&, int32_t>(
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
}
};
/**
* @class MmaDefaultTransformsSparse
* @brief Implements the default transforms for Sparse
*
* For 2:4 structured sparsity with inline register metadata:
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
* - ATransform: 2:4 structured sparsity compression
* - BTransform: Pass-through (sparse operands already formatted)
* - CTransform: Pass-through (input accumulator)
* - DTransform: Pass-through (output accumulator as-is)
*/
template <index_t CompressionRatio>
struct MmaDefaultTransformsSparse
{
using ATransform = PassThroughTransform;
using ATransform = SparseCompressTransform<CompressionRatio>;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
@@ -42,7 +121,7 @@ struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
{
using SelectedTransforms = MmaDefaultTransformsSparse;
using SelectedTransforms = MmaDefaultTransformsSparse<MmaOp::kCompressionRatio>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx11_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx11_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx11;
};
@@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector<MmaOp,
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx12_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx12;
};

View File

@@ -209,4 +209,21 @@ template <typename ADataType, typename BDataType>
using largest_type_t =
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
/**
* @brief Type trait to detect whether a type is a @c std::tuple specialization.
* @tparam T The type to inspect.
*/
template <typename T>
struct is_std_tuple : std::false_type
{
};
template <typename... Args>
struct is_std_tuple<std::tuple<Args...>> : std::true_type
{
};
template <typename T>
static constexpr bool is_std_tuple_v = is_std_tuple<T>::value;
} // namespace ck_tile

View File

@@ -5,52 +5,9 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
@@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl
return WarpGemmAttribute_::get_num_of_access();
}
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
//----------------------------------------------------------------------------------------------
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
/// elements into lower part of a_vec to half its effective size.
///
/// @param a_vec Vector to be compressed.
///
/// @return Four 2-bit indexes of non-zero elements locations
///
template <typename AVec>
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
{
return compress_a_impl<ADataType, CompressedSize>(a_vec);
int32_t idx = 0b11101110;
static_for<0, 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
@@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
const int32_t idx = compress_a(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};