[rocm-libraries] ROCm/rocm-libraries#6212 (commit ccee58d)

[CK TILE] Unification Work – More accurate tests for MmaPipelines (#6212)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR solves several issues:

#### More accurate tests for MmaPipelines

The current tests for the MmaPipelines (test_amdgcn_sparse_mma,
test_amdgcn_wavewise_mma) use explicit input fragment vectors filled
with 1s, and only check the output of a single lane. We should have
tests that actually use the MmaPipelines with non-trivial input matrices
and verify the complete output.
Some other aspects of the current MmaPipelines tests that I noticed and
that deserve some attention:

1. There is sometimes iteration over K outside of the pipeline, which is
then included in WaveTileK or FragK, which is not correct. We should
remove it, move K iteration inside of the pipeline, or be more clear
about this outer-K loop size and how it propagates downwards.
2. There is very tight coupling between the kernel, gtest code, and
test_pipeline helper, requiring a lot of information and functions to be
passed back and forth.
3. The test_pipeline helper is doing a bunch of register-related logic
on the host (related to point 1)
4. Without this register logic the only thing it does is check the
device, call the kernel, and check the output, but with a lot of
boilerplate.

#### Test helper for detecting target arch at HOST runtime

There is a really apparent issue we faced while writing tests:

Scenario:
1. Compile a test that supports both gfx950 and gfx1201 for gfx950
2. Run the test on a server that only has gfx1201 GPU

Actual:
Segmentation fault

Expected:
The test can correctly detect from HOST runtime that the DEVICE
target_id was different and skips the test.

Notes:

The only way of detecting the COMPILER_TARGET_ID in the existing "arch"
framework is to launch a kernel and call `get_compiler_target()` (i.e.,
from DEVICE code). This causes a segmentation fault if the
current arch differs from the target arch. To cope with this issue, we
propose to export the compiler target(s) (note that there can be several) through
`projects/composablekernel/test/ck_tile/core/arch/CMakeLists.txt` and
define a test helper to deal with such cases.

#### Add composition support to Transforms

We have a small number of Transforms which act on MmaOp input and output
data, before and after the MmaOp call respectively. These are currently
implemented to work on an MmaTile level, but in theory they are also
supposed to work at a WaveTile level, i.e. after composition of multiple
MmaTiles to create larger effective MNK dimensions. Currently the
composed MmaTiles look like 2D C-style arrays of the individual MmaTile
level register vectors (see WaveWiseMmaPipeline). The transforms should
be able to take these and perform the proper transforms to the whole
WaveTile at once. This might allow for better performing
transformations.

Note: This PR handles the SparseTransform case and if we don't end up
doing scale as a transformation, there isn't really much left to do. If
we end up having only the sparse transform as a non-trivial transform,
then we could also consider removing the Transform framework.
This commit is contained in:
chris-tsiaousis-hpc
2026-06-03 14:35:18 +00:00
committed by assistant-librarian[bot]
parent 88f8d24c34
commit db05d61136
20 changed files with 1646 additions and 589 deletions

View File

@@ -24,6 +24,7 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"

View File

@@ -393,10 +393,3 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
#if __clang_major__ >= 23
#pragma clang diagnostic pop
#endif
// Include the implementations
#include "wmma/wmma.hpp" // should be included before the below headers
#include "mfma/mfma.hpp"
#include "scale/scale.hpp"
#include "sparse/sparse.hpp"

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
namespace ck_tile::core::arch::mma {

View File

@@ -0,0 +1,8 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "wmma/wmma.hpp"
#include "mfma/mfma.hpp"
#include "sparse/sparse.hpp"

View File

@@ -270,12 +270,20 @@ struct MmaPipelineBase
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
constexpr bool swap_a_and_b = hasFlag<MmaPipelineOptionFlag::ABSwap>();
auto transformed_inputs = [&]() {
if constexpr(swap_a_and_b)
{
return applyTransformsToInputs(
std::forward<VecTB>(b), std::forward<VecTA>(a), std::forward<VecTC>(accum));
}
else
{
return applyTransformsToInputs(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
}();
Derived::execImpl(transformed_inputs);
@@ -302,17 +310,18 @@ struct MmaPipelineBase
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
static_assert(MmaOpTraits<typename Derived::MmaOp>::IsScale,
"This exec variant is intended for scale policy structs");
constexpr bool swap_a_and_b = hasFlag<MmaPipelineOptionFlag::ABSwap>();
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
swap_a_and_b ? std::forward<VecTB>(b) : std::forward<VecTA>(a),
swap_a_and_b ? std::forward<VecTA>(a) : std::forward<VecTB>(b),
std::forward<VecTC>(accum),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
swap_a_and_b ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
swap_a_and_b ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
Derived::execImpl(transformed_inputs);

View File

@@ -169,7 +169,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
}
else
{
static_assert(false);
static_assert(false, "Invalid accumulation policy");
}
}
};

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/config.hpp"
@@ -16,12 +17,32 @@
namespace ck_tile::core::arch::mma {
/**
* @class ScaleMmaPipeline
* @brief Driver for the wave-tile scale Mma operation. Given a backend MmaOp implementation
* (e.g., scale MFMA), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates
* internally over FragsM x FragsN x FragsK, passing per-wave scale factors to each fragment call.
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t FragM,
std::uint32_t FragN,
std::uint32_t FragK,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
@@ -30,37 +51,55 @@ template <typename ADataType,
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SCALE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Expose caller-side vector types
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Fragment dimensions (from the hardware MmaOp)
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Expose internal vector types
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
// Vector types for packed registers in each fragment
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using AVecType = InternalAVecT[FragsM][FragsK];
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename VecTA,
typename VecTB,
typename VecTC,
@@ -69,8 +108,40 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
CK_TILE_DEVICE static void
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
{
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
auto& [a_frag, b_frag, c_frag, scale_A, scale_B] = vecs;
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B);
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B);
}
}
}
}
else
{
static_assert(false, "Invalid accumulation policy");
}
}
};

View File

@@ -5,10 +5,10 @@
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
@@ -20,12 +20,33 @@ constexpr inline int getPipelineFlags()
}
} // namespace sparse::detail
/**
* @class SparseMmaPipeline
* @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation
* (e.g., smfmac), this class performs fragment-wise (MmaTile) decomposition to matrix-multiply
* input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and accumulates
* results into output WaveTile (C: WaveTileM x WaveTileN).
* Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates
* internally over FragsM × FragsN × FragsK. The A operand is provided in uncompressed form;
* 2:4 structured sparsity compression (SparseCompressTransform) is applied.
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
@@ -34,17 +55,17 @@ template <typename ADataType,
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
@@ -52,48 +73,153 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Calculate the uncompressed A vector type
// Fragment dimensions (from the hardware MmaOp)
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
// Calculate the uncompressed external A per-fragment vector type
struct ExternalAVecCalculator
{
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
};
using ExternalAFragVecT = typename ExternalAVecCalculator::AVecType;
// Expose caller-side vector types
using AVecType = typename ExternalAVecCalculator::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Scalar type of A
using AScalarT = typename ExternalAVecCalculator::AVecTraits::scalar_type;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
// Per-fragment sizes
static constexpr uint32_t ExternalAFragSize = ExternalAVecCalculator::ASize;
static constexpr uint32_t InternalAFragSize =
vector_traits<typename MmaOp::AVecType>::vector_size;
// Full wave-tile sizes (all fragments combined)
static constexpr uint32_t TotalUncompressedElems = FragsM * FragsK * ExternalAFragSize;
static constexpr uint32_t TotalCompressedElems =
TotalUncompressedElems / MmaOp::kCompressionRatio;
// Variable-length idx type for the whole wave-tile (spans multiple int32_t words if needed)
static constexpr index_t IdxNumWords = sparse::detail::idx_words_needed<TotalCompressedElems>;
using IdxType = sparse::detail::SparseIdxPack<IdxNumWords>;
// Per-fragment compressed vector type (for individual MmaOp::exec calls)
using FragAVecT = typename MmaOp::AVecType;
// Internal vector types used by the base class formatBuffer.
// InternalAVecT matches the full compressed wave-tile so the base class can
// format the SparseCompressTransform result via formatBuffer<InternalAVecT>.
using InternalAVecT = ext_vector_t<AScalarT, TotalCompressedElems>;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles (caller-facing).
// A is a single flat uncompressed vector covering the whole wave-tile.
// The base class compresses it in one pass via ATransform.
using AVecType = ext_vector_t<AScalarT, TotalUncompressedElems>;
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static void
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& transformedInputs)
{
auto& [a, b_frag, c_frag] = transformedInputs;
auto& [a_compressed_whole, idx] = a;
// Validate that the ATransform result and per-fragment reinterpretation are correct
checkATransformResult<ATransformResult>();
auto& [a_result, b_vec, c_vec] = vecs;
auto& [a_vec, idx] = a_result;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
// Reinterpret the full compressed vector as per-fragment arrays
auto* a_frags = ck_tile::bit_cast<FragAVecT(*)[FragsK]>(&a_compressed_whole);
// Accumulation loop with per-fragment idx extraction
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frags[bm][bk],
b_frag[bn][bk],
c_frag[bm][bn],
sparse::detail::extract_fragment_idx<InternalAFragSize, FragsK>(
idx, bm, bk));
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frags[bm][bk],
b_frag[bn][bk],
c_frag[bm][bn],
sparse::detail::extract_fragment_idx<InternalAFragSize, FragsK>(
idx, bm, bk));
}
}
}
}
else
{
static_assert(false, "Invalid accumulation policy");
}
}
private:
// Type check helper - not a device function, so std::declval is available
// Compile-time validation of ATransform result and per-fragment reinterpretation.
// Ensures the compressed vector returned by ATransform::exec can be safely
// reinterpreted as FragAVecT[FragsM][FragsK] for per-fragment MmaOp dispatch.
template <typename ATransformResult>
static constexpr void checkATransformResult()
{
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
static_assert(std::is_same_v<ATransformResult,
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>,
"ATransformResult must match the return type of ATransform::exec");
using CompressedVecType =
std::remove_reference_t<std::tuple_element_t<0, ATransformResult>>;
static_assert(sizeof(CompressedVecType) == sizeof(FragAVecT) * FragsM * FragsK,
"Compressed A vector size must equal sizeof(FragAVecT[FragsM][FragsK])");
static_assert(alignof(CompressedVecType) >= alignof(FragAVecT),
"Compressed vector alignment must be >= FragAVecT alignment "
"for safe reinterpret_cast to per-fragment array");
using ActualIdxType = std::tuple_element_t<1, ATransformResult>;
static_assert(std::is_same_v<ActualIdxType, IdxType>,
"Sparsity index type must match SparseIdxPack<IdxNumWords>");
}
};

View File

@@ -13,6 +13,30 @@
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
/// Count of int32_t words required to hold CompressedSize index fields of 2 bits each,
/// i.e. (2 * CompressedSize) bits rounded up to a whole number of 32-bit words.
template <index_t CompressedSize>
static constexpr index_t idx_words_needed = (2 * CompressedSize + (32 - 1)) / 32;
/**
 * @class SparseIdxPack
 * @brief Compile-time-sized holder for 2:4 structured sparsity index metadata.
 *
 * Every compressed element contributes one 2-bit index field that records the
 * element's original position (0-3) within its group of 4. When multiple MMA
 * fragments are composed along M and K inside a WaveTile, the total number of
 * index bits can exceed 32, so the fields are stored in an array of int32_t
 * words whose length is fixed at compile time. All words start out as zero.
 *
 * @tparam NumWords Number of int32_t words needed to store all index fields.
 */
template <index_t NumWords>
struct SparseIdxPack
{
    static_assert(NumWords > 0, "SparseIdxPack requires at least 1 word");

    // Packed 2-bit index fields; value-initialized so every field starts at 0.
    int32_t words[NumWords]{};
};
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
@@ -20,21 +44,29 @@ namespace sparse::detail {
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32-bit word containing **CompressedSize** 2-bit fields.
* Each field encodes the original position (0-3) of the corresponding
* non-zero element in the input. If fewer than CompressedSize
* non-zeros are found, remaining fields default to 2 (see below).
* @return SparseIdxPack containing **CompressedSize** 2-bit fields packed
* across one or more int32_t words. Each field encodes the original
* position (0-3) of the corresponding non-zero element in the input.
* If fewer than CompressedSize non-zeros are found, remaining fields
* default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec)
{
// idx holds one 2-bit index per output element (total CompressedSize entries).
static constexpr index_t NumIdxWords = idx_words_needed<CompressedSize>;
// idx holds one 2-bit index per output element (total CompressedSize entries),
// packed across NumIdxWords int32_t words.
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two non-zero values
// in a 4-element group: the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real non-zeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
SparseIdxPack<NumIdxWords> idx{};
static_for<0, CompressedSize, 1>{}([&](auto k) {
constexpr uint32_t bit_pos = static_cast<uint32_t>(k) * 2u;
constexpr uint32_t word = bit_pos / 32u;
constexpr uint32_t shift = bit_pos % 32u;
idx.words[word] |= static_cast<int32_t>(2u << shift);
});
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
@@ -45,8 +77,13 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the two-bit field for this output and insert j
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
const uint32_t field_idx =
static_cast<uint32_t>(i) * 2u + static_cast<uint32_t>(non_zero_pos);
const uint32_t bit_pos = field_idx * 2u;
const uint32_t word = bit_pos / 32u;
const uint32_t shift = bit_pos % 32u;
idx.words[word] &= ~static_cast<int32_t>(0b11u << shift);
idx.words[word] |= static_cast<int32_t>(static_cast<uint32_t>(j) << shift);
++non_zero_pos;
}
});
@@ -56,6 +93,40 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
return idx;
}
/**
 * @brief Extract the per-fragment sparsity index from a packed idx pack.
 * After whole-wave-tile compression, the returned idx packs 2-bit fields for
 * every compressed output element across one or more int32_t words. Fragments
 * are laid out consecutively in the pack, linearized as (m * FragsK + k).
 * @tparam FragCompressedSize Number of compressed elements (2-bit fields) per fragment.
 *         The resulting bit count (2 * FragCompressedSize) must fit in 32 bits.
 * @tparam FragsK Number of fragments along K, used to linearize (m, k).
 * @tparam NumIdxWords Number of int32_t words in the pack (deduced from idx).
 * @param idx Packed index fields for the whole wave-tile.
 * @param m Fragment index along M.
 * @param k Fragment index along K.
 * @return A single int32_t with this fragment's 2-bit fields at the
 * least-significant positions, suitable for passing to the MMA builtin.
 * Bits above the fragment's own fields are NOT cleared; they hold whatever
 * followed this fragment in the source word(s).
 */
template <uint32_t FragCompressedSize, uint32_t FragsK, index_t NumIdxWords>
static CK_TILE_DEVICE int32_t extract_fragment_idx(const SparseIdxPack<NumIdxWords>& idx,
                                                   uint32_t m,
                                                   uint32_t k)
{
    static constexpr uint32_t IdxBitsPerFrag = FragCompressedSize * 2;
    // The result is a single 32-bit word: a wider fragment would silently lose
    // index fields, so reject it at compile time.
    static_assert(IdxBitsPerFrag > 0u && IdxBitsPerFrag <= 32u,
                  "Per-fragment index bits must fit in the returned int32_t");

    const auto fragLinearIdx  = m * FragsK + k;
    const auto totalBitOffset = fragLinearIdx * IdxBitsPerFrag;
    const auto wordIdx        = totalBitOffset / 32u;
    const auto bitInWord      = totalBitOffset % 32u;

    // Bring the fragment's first field down to bit 0.
    uint32_t result = static_cast<uint32_t>(idx.words[wordIdx]) >> bitInWord;

    // If fragment bits span a word boundary, stitch in bits from the next word.
    // (This is a safety measure; it should not occur when IdxBitsPerFrag is a
    // power-of-2 divisor of 32, which is always the case for current MMA ops.)
    if constexpr(NumIdxWords > 1)
    {
        // bitInWord != 0 also keeps the shift amount below 32 (a shift by 32 is undefined).
        if(bitInWord != 0 && bitInWord + IdxBitsPerFrag > 32u)
        {
            result |= static_cast<uint32_t>(idx.words[wordIdx + 1]) << (32u - bitInWord);
        }
    }

    return static_cast<int32_t>(result);
}
} // namespace sparse::detail
/**
@@ -75,15 +146,15 @@ struct SparseCompressTransform
static constexpr auto VecN = VecTraits::vector_size;
static constexpr index_t CompressedSize = VecN / CompressionRatio;
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
using IdxType =
sparse::detail::SparseIdxPack<sparse::detail::idx_words_needed<CompressedSize>>;
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
// TODO c++20: Use bit_cast
return std::tuple<VecCompressed&, int32_t>(
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
return std::tuple<VecCompressed&, IdxType>(*ck_tile::bit_cast<VecCompressed*>(&v), idx);
}
};