[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)

[CK TILE] Unification of sparse MFMA/WMMA policy structs
 (#4837)

## Motivation

The existing unification work supports DENSE intrinsics. In this PR we
enable support for SPARSE as well as SCALE intrinsics and add an example
SPARSE implementation.

## Technical Details

Mostly trivial changes. One framework change is that the desired
`MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant
commit explains, we do not support a fallback family at the moment, but
it is something we can consider.

## Test Plan

Added a new test for the relevant sparse specializations.

## Test Result

Test should pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
chris-tsiaousis-hpc
2026-03-05 19:53:16 +00:00
committed by assistant-librarian[bot]
parent 6e558658ea
commit 03ce21ddcb
23 changed files with 1173 additions and 89 deletions

View File

@@ -823,6 +823,12 @@ using enable_if_target_wave64_t =
#endif // __cplusplus <= 201703L
/// True iff every type in the pack is exactly `void` (an empty pack is vacuously true;
/// cv-qualified void does not count, matching std::is_same semantics).
template <typename... Ts>
constexpr bool all_types_void = (std::is_same_v<void, Ts> && ...);
/// SFINAE helper that combines several `std::enable_if_t` results into one enabler slot:
/// well-formed (and `void`) only when every enabler type evaluated to `void`.
template <typename... Enablers>
using enable_if_all = std::enable_if_t<all_types_void<Enablers...>>;
} // namespace core::arch
CK_TILE_HOST bool is_wave32()

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/ignore.hpp"
@@ -82,11 +83,13 @@ template <typename ADataType,
uint32_t BlockK,
typename CtrlFlags,
typename CompilerTarget,
MmaOpFamily OpFamily_,
typename Enabler = void>
struct amdgcn_mma
{
// The base instance is unsupported because there is no __builtin to wrap.
using OpType = Unsupported;
using OpType = Unsupported;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
// Interface types for A, B, C vectors types
using AVecType = ext_vector_t<ADataType, 1>;
@@ -122,3 +125,4 @@ struct amdgcn_mma
// Include the implementations
#include "wmma/wmma.hpp"
#include "mfma/mfma.hpp"
#include "sparse/sparse.hpp"

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
@@ -68,10 +69,12 @@ struct amdgcn_mma<fp16_t,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_family_gfx9_t<CompilerTarget>>
{
// Mfma operation type
using OpType = MfmaOp;
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp16_t, 4>;
@@ -125,9 +128,11 @@ struct amdgcn_mma<fp16_t,
32u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
{
using OpType = MfmaOp;
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Packed register types
using AVecType = ext_vector_t<fp16_t, 8>;

View File

@@ -52,7 +52,8 @@ struct MfmaDefaultSelector
BlockN,
BlockKTest,
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
CompilerTarget>;
CompilerTarget,
MmaOpFamily::DENSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
@@ -98,7 +99,8 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
BlockN,
1u,
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
CompilerTarget>;
CompilerTarget,
MmaOpFamily::DENSE>;
};
/**
@@ -114,6 +116,7 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
@@ -121,7 +124,8 @@ template <typename ADataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
typename CompilerTarget,
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
@@ -129,7 +133,9 @@ struct MmaDefaultSelector<ADataType,
FragN,
FragK,
CompilerTarget,
enable_if_target_family_gfx9_t<CompilerTarget>>
OpFamily,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
{
private:
// Provide the default depth-K search strategy for each class of common MFMA shapes.

View File

@@ -57,6 +57,7 @@ template <typename ADataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
@@ -69,7 +70,8 @@ template <typename ADataType,
FragM,
FragN,
FragK,
CompilerTarget>::SelectedOp,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <type_traits> // std::false_type / std::true_type / std::enable_if_t used below

namespace ck_tile::core::arch::mma {

/**
 * @enum MmaOpFamily
 * @brief Enumeration of the supported MMA operation families.
 */
enum struct MmaOpFamily
{
    UNDEFINED = 0, // No family selected / unsupported pass-through operations
    DENSE,
    SPARSE,
    SCALE,
};

/**
 * @class is_mma_op_of_family
 * @brief Meta-function to check if MmaOp is of the specified MmaOpFamily
 * @tparam Family The MMA operation family to test against
 * @tparam MmaOp amdgcn struct specialization type
 */
template <MmaOpFamily Family, typename MmaOp, typename = void>
struct is_mma_op_of_family : std::false_type
{
};

/**
 * @struct is_mma_op_of_family
 * @brief Specialization for Family == MmaOp::OpFamily detection.
 * SFINAE keeps the primary (false) template selected for types without an OpFamily member.
 */
template <MmaOpFamily Family, typename MmaOp>
struct is_mma_op_of_family<Family, MmaOp, std::enable_if_t<Family == MmaOp::OpFamily>>
    : std::true_type
{
};

/**
 * @brief Convenience evaluator for is_mma_op_of_family trait
 * @tparam Family Desired MMA operation family
 * @tparam MmaOp The amdgcn struct specialization type to check
 */
template <MmaOpFamily Family, typename MmaOp>
static constexpr bool is_mma_op_of_family_v = is_mma_op_of_family<Family, MmaOp>::value;

} // namespace ck_tile::core::arch::mma

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
namespace ck_tile::core::arch::mma {
@@ -15,11 +16,12 @@ namespace ck_tile::core::arch::mma {
* architecture.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam CDataType Data type of the accumulator
* @tparam FragM Fragment M dimension
* @tparam FragN Fragment N dimension
* @tparam FragK Fragment K dimension
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam Enable SFINAE enabler
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
@@ -34,14 +36,22 @@ template <typename ADataType,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget,
MmaOpFamily OpFamily,
typename Enable = void>
// TODO c++20 requires
struct MmaDefaultSelector
{
// By default, no selection is made, and we fall back to a pass-through unsupported
// implementation. This is because we do not have any knowledge of the target architecture here.
using SelectedOp =
amdgcn_mma<ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<>>;
using SelectedOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>;
};
#if CK_TILE_CONCEPTS

View File

@@ -1,6 +1,8 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "amdgcn_mma.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "mfma/mfma_traits.hpp"
@@ -69,6 +71,7 @@ concept MmaOpParamsI = requires(MmaOpParams op) {
{ MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
{ MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -92,7 +95,8 @@ template <typename ADataType_,
uint32_t BlockN_,
uint32_t BlockK_,
typename CtrlFlags_,
typename CompilerTarget_>
typename CompilerTarget_,
MmaOpFamily OpFamily_>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
struct MmaOpParams<amdgcn_mma<ADataType_,
BDataType_,
@@ -101,17 +105,19 @@ struct MmaOpParams<amdgcn_mma<ADataType_,
BlockN_,
BlockK_,
CtrlFlags_,
CompilerTarget_>>
CompilerTarget_,
OpFamily_>>
{
// Capture incoming template parameters
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
static constexpr uint32_t BlockM = BlockM_;
static constexpr uint32_t BlockN = BlockN_;
static constexpr uint32_t BlockK = BlockK_;
using CtrlFlags = CtrlFlags_;
using CompilerTarget = CompilerTarget_;
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
static constexpr uint32_t BlockM = BlockM_;
static constexpr uint32_t BlockN = BlockN_;
static constexpr uint32_t BlockK = BlockK_;
using CtrlFlags = CtrlFlags_;
using CompilerTarget = CompilerTarget_;
static constexpr auto MmaOpFamily = OpFamily_;
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
};
@@ -131,6 +137,8 @@ struct MmaOpTraits : public MmaOpParams<MmaOp>
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
// Capture layout parameters
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
@@ -144,9 +152,13 @@ struct MmaOpTraits : public MmaOpParams<MmaOp>
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
// Additional traits to identify the type of MmaOp at compile time
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
constexpr static bool IsSupported = is_mma_op_supported_v<MmaOp>;
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
constexpr static bool IsSupported =
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,151 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
namespace ck_tile::core::arch::mma {
/**
* @class SparseMfmaDefaultSelector
* @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam BlockKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct SparseMfmaDefaultSelector
{
private:
// Define our candidate MFMA implementation for the current parameters
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
DefaultSparseMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
CandidateOp,
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
};
/**
 * @struct MmaDefaultSelector
 * @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * Enabled only for GFX942/GFX950 targets and OpFamily == MmaOpFamily::SPARSE.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 * @tparam OpFamily The MMA operation family
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          OpFamily,
                          enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
                                                                         amdgcn_target_id::GFX942,
                                                                         amdgcn_target_id::GFX950)>,
                                        std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
{
    private:
    // Provide the default depth-K search strategy for each class of common MFMA shapes.
    // Start searching from the largest K dimension MFMA shape down to the smallest.
    // 16x16x32 sparse MFMA candidate (may resolve to the unsupported pass-through).
    using CandidateOp16x16 = typename SparseMfmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                16u,
                                                                16u,
                                                                32u,
                                                                CompilerTarget>::SelectedOp;
    // 32x32x64 sparse MFMA candidate (may resolve to the unsupported pass-through).
    using CandidateOp32x32 = typename SparseMfmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                32u,
                                                                32u,
                                                                64u,
                                                                CompilerTarget>::SelectedOp;
    // Default operation triggers pass-through (1x1x1 has no backend implementation)
    using DefaultOp = typename SparseMfmaDefaultSelector<ADataType,
                                                         BDataType,
                                                         CDataType,
                                                         1u,
                                                         1u,
                                                         1u,
                                                         CompilerTarget>::SelectedOp;
    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
    using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the MFMA shape
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);
    static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
                                             (FragM % CandidateTraits32x32::BlockM == 0u) &&
                                             (FragN % CandidateTraits32x32::BlockN == 0u) &&
                                             (FragK % CandidateTraits32x32::BlockK == 0u);

    public:
    // Select the largest supported MFMA operation for the given fragment shape
    using SelectedOp =
        std::conditional_t<IsSupported32x32,
                           CandidateOp32x32,
                           std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,108 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @struct DefaultSparseMfmaCtrlFlags
 * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
 * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
 */
struct DefaultSparseMfmaCtrlFlags
{
    // Use the first set of sparse indices in the metadata VGPR.
    static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
};

#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
// NOTE(review): as rendered here this #include sits inside namespace ck_tile::core::arch::mma —
// confirm it is hoisted outside the namespace in the actual header, otherwise the standard
// library declarations would be wrapped into this namespace.
#include <concepts>
/**
 * @concept SparseMfmaCtrlFlags
 * @brief Expresses the interface of required members for each CtrlFlags type
 */
template <typename CtrlFlags>
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
    // Flag members for sparse MFMA instructions
    { CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
 *
 * This specialization implements the SMFMA instruction for fp16_t A and B
 * matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
 *
 * @tparam CtrlFlags Control flags for the Sparse MFMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<
    fp16_t,
    fp16_t,
    fp32_t,
    16u,
    16u,
    32u,
    CtrlFlags,
    CompilerTarget,
    MmaOpFamily::SPARSE,
    std::enable_if_t<is_any_value_of(
        CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
{
    using OpType = MfmaOp;
    static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;

    // A/B carry 8 fp16 elements per lane before compression (half of A is pruned in exec()).
    static constexpr index_t ABVecN = 8;
    using AVecType = ext_vector_t<fp16_t, ABVecN>;
    using BVecType = ext_vector_t<fp16_t, ABVecN>;
    using CVecType = ext_vector_t<fp32_t, 4>;

    // Layout parameters: lane decomposition of the 16x16x32 block.
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABKLane = 4;
    static constexpr index_t kABKPerLane = 8;
    static constexpr index_t kCMLane = 4;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;
    // 2:4 structured sparsity: A stores half as many elements as the dense K extent.
    static constexpr index_t kCompressionRatio = 2;

    /// @brief Executes one sparse 16x16x32 fp16 MFMA: D = A(sparse) * B + C.
    /// @param aVec Uncompressed A operand; compressed in place by compress_a_impl, which
    ///        moves the non-zero elements to the lower half and returns the packed 2-bit
    ///        position indices.
    /// @param bVec Dense B operand.
    /// @param cVec Input accumulator.
    /// @return The accumulated result vector.
    CK_TILE_DEVICE static auto
    exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
        using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
        static_assert(CompressedSize == 4);
        // TODO: Compressing A on-the-fly should be OK for now, but we need to validate
        // and evaluate changing this to a transform at a higher level.
        // aVec not being const can cause problems when running multiple intrinsics.
        const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
        // After compression the surviving values occupy the lower half of aVec.
        const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
        using namespace sparse::detail;
        // Translate the compile-time compression-index choice into CBSZ/ABID operands.
        static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
        return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
            a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,68 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {

/**
 * @enum SparseCompressionIndex
 * @brief Selects which set of sparse indices within the VGPR starting at srcC the sparse
 * builtins consume. Each lane's set holds 8 bits of index data for 16-bit source data, or
 * 16 bits for 8-bit source data. \see DefaultSparseMfmaCtrlFlags
 */
enum struct SparseCompressionIndex : int
{
    FIRST = 0,  // Uses bits [7:0] or [15:0], for 16 and 8 bit data respectively
    SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
    THIRD = 2,  // Uses bits [23:16]
    FOURTH = 3, // Uses bits [31:24]
};

namespace sparse::detail {

/**
 * @struct BuiltinParams
 * @brief CBSZ/ABID operand pair for the sparse builtins. Semantics depend on the source type:
 * - 16-bit source data: with CBSZ=0, ABID selects one of four 8-bit sets of sparse indices
 *   within the VGPR starting at srcC; with CBSZ!=0 the first set (VGPR[srcC][7..0]) is used.
 * - 8-bit source data: with CBSZ=0, ABID selects one of two 16-bit sets of sparse indices;
 *   with CBSZ!=0 the first set (VGPR[srcC][15..0]) is used.
 */
struct BuiltinParams
{
    int UseFirstIndex;       // CBSZ
    int ByteIndexToOverride; // ABID
};

/**
 * @brief Maps a SparseCompressionIndex onto the CBSZ/ABID pair the builtins expect.
 * FIRST yields CBSZ=1 (ABID ignored, first index set used); any other value yields CBSZ=0
 * with the index passed through as ABID.
 */
template <SparseCompressionIndex Idx>
static constexpr BuiltinParams getBuiltinParams()
{
    if constexpr(Idx == SparseCompressionIndex::FIRST)
    {
        return BuiltinParams{/*UseFirstIndex=*/1, /*ByteIndexToOverride=*/0};
    }
    else
    {
        return BuiltinParams{/*UseFirstIndex=*/0, static_cast<int>(Idx)};
    }
}

} // namespace sparse::detail
} // namespace ck_tile::core::arch::mma
// Include sparse MFMA traits and architecture-specific implementations
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"

View File

@@ -0,0 +1,7 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
namespace ck_tile::core::arch::mma {

/**
 * @struct MmaDefaultTransformsSparse
 * @brief Implements the default transforms for Sparse
 *
 * For 2:4 structured sparsity with inline register metadata:
 * - ATransform: Pass-through (A is currently compressed inside the op's exec();
 *   TODO: move the compression into a proper transform)
 * - BTransform: Pass-through (sparse operands already formatted)
 * - CTransform: Pass-through (input accumulator)
 * - DTransform: Pass-through (output accumulator as-is)
 */
struct MmaDefaultTransformsSparse
{
    using ATransform = PassThroughTransform;
    using BTransform = PassThroughTransform;
    using CTransform = PassThroughTransform;
    using DTransform = PassThroughTransform;
};

/**
 * @class MmaTransformsDefaultSelector
 * @brief Specialization for Sparse MFMA transforms.
 * Provides default (pass-through) transform selection for any MmaOp whose OpFamily
 * is MmaOpFamily::SPARSE.
 *
 * @tparam MmaOp Sparse MMA operation
 * @tparam CompilerTarget The compiler target
 */
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target CompilerTarget>
// TODO: c++20 requires(is_mma_op_sparse(MmaOp))
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
                                    CompilerTarget,
                                    std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
{
    using SelectedTransforms = MmaDefaultTransformsSparse;
};

} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
namespace ck_tile::core::arch::mma {
/**
* @class SparseWmmaDefaultSelector
* @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam BlockKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct SparseWmmaDefaultSelector
{
private:
// Define our candidate WMMA implementation for the current parameters
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
DefaultSparseWmmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
CandidateOp,
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
};
/**
 * @struct MmaDefaultSelector
 * @brief Implements the RDNA default MMA selector strategy for sparse WMMA.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * Enabled only for GFX12-family targets and OpFamily == MmaOpFamily::SPARSE.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 * @tparam OpFamily The MMA operation family
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          OpFamily,
                          enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
                                        std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
{
    private:
    // Provide the default depth-K search strategy for each class of common WMMA shapes.
    // Start searching from the largest K dimension WMMA shape down to the smallest.
    // 16x16x32 sparse WMMA candidate (may resolve to the unsupported pass-through).
    using CandidateOp16x16 = typename SparseWmmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                16u,
                                                                16u,
                                                                32u,
                                                                CompilerTarget>::SelectedOp;
    // Default operation triggers pass-through (1x1x1 has no backend implementation)
    using DefaultOp = typename SparseWmmaDefaultSelector<ADataType,
                                                         BDataType,
                                                         CDataType,
                                                         1u,
                                                         1u,
                                                         1u,
                                                         CompilerTarget>::SelectedOp;
    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the WMMA shape
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);

    public:
    // Select the largest supported WMMA operation for the given fragment shape
    using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,73 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @struct DefaultSparseWmmaCtrlFlags
 * @brief Default control flags for sparse WMMA. Currently empty — the gfx12 swmmac
 * builtin used below takes no additional mode operands.
 */
struct DefaultSparseWmmaCtrlFlags
{
};

/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for sparse WMMA (SWMMAC) on GFX12-family targets.
 * Implements fp16_t A (structured sparse) x fp16_t B with fp32_t accumulator for a
 * 16x16x32 block.
 * @tparam CtrlFlags Control flags for the sparse WMMA operation (unused by this builtin)
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  32u,
                  CtrlFlags,
                  CompilerTarget,
                  MmaOpFamily::SPARSE,
                  enable_if_target_family_gfx12_t<CompilerTarget>>
{
    using OpType = WmmaOp;
    static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;

    // A/B carry 16 fp16 elements per lane before compression (half of A is pruned in exec()).
    static constexpr index_t ABVecN = 16;
    using AVecType = ext_vector_t<fp16_t, ABVecN>;
    using BVecType = ext_vector_t<fp16_t, ABVecN>;
    using CVecType = ext_vector_t<fp32_t, 8>;

    // Layout parameters: lane decomposition of the 16x16x32 block.
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABKLane = 4;
    static constexpr index_t kABKPerLane = 8;
    static constexpr index_t kCMLane = 4;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;
    // 2:4 structured sparsity: A stores half as many elements as the dense K extent.
    static constexpr index_t kCompressionRatio = 2;

    /// @brief Executes one sparse 16x16x32 fp16 WMMA: D = A(sparse) * B + C.
    /// @param aVec Uncompressed A operand; compressed in place by compress_a_impl, which
    ///        moves the non-zero elements to the lower half and returns the packed 2-bit
    ///        position indices.
    /// @param bVec Dense B operand.
    /// @param cVec Input accumulator.
    /// @return The accumulated result vector.
    CK_TILE_DEVICE static auto
    exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
        using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
        static_assert(CompressedSize == 8);
        // TODO: Compressing A on-the-fly should be OK for now, but we need to validate
        // and evaluate changing this to a transform at a higher level.
        // aVec not being const can cause problems when running multiple intrinsics.
        const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
        // After compression the surviving values occupy the lower half of aVec.
        const AVecCompressed a_vec_pruned = {
            aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
        return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
@@ -77,10 +78,12 @@ struct amdgcn_mma<fp16_t,
16u,
CtrlFlags,
CompilerTarget,
enable_if_target_family_gfx11_t<CompilerTarget>>
MmaOpFamily::DENSE,
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
{
// Wmma operation type
using OpType = WmmaOp;
using OpType = WmmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types (duplicated input / b32 accum)
using AVecType = ext_vector_t<fp16_t, 16>;

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
@@ -37,10 +38,12 @@ struct amdgcn_mma<fp16_t,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_family_gfx12_t<CompilerTarget>>
{
// Wmma operation type
using OpType = WmmaOp;
using OpType = WmmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp16_t, 8>;

View File

@@ -46,7 +46,8 @@ struct WmmaDefaultSelector
BlockN,
BlockKTest,
CtrlFlags,
CompilerTarget>;
CompilerTarget,
MmaOpFamily::DENSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
@@ -91,8 +92,15 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
// Default unsupported pass-through if no instruction is found
using SelectedOp =
amdgcn_mma<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CtrlFlags, CompilerTarget>;
using SelectedOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
1u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>;
};
/**
@@ -108,6 +116,7 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
@@ -115,7 +124,8 @@ template <typename ADataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget>
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
@@ -125,7 +135,9 @@ struct MmaDefaultSelector<ADataType,
FragN,
FragK,
CompilerTarget,
enable_if_target_arch_rdna_t<CompilerTarget>>
OpFamily,
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
{
private:
// Provide the default depth-K search strategy for each class of common WMMA shapes.

View File

@@ -4,8 +4,54 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
 * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
 * elements into lower part of a_vec to half its effective size.
 * @param a_vec Vector to be compressed.
 * @tparam ADataType The data type of a_vec
 * @tparam CompressedSize The target compression size (two kept elements per 4-wide group)
 * @tparam AVec The vector type of a_vec (deduced)
 * @return Packed 32-bit word containing **CompressedSize** 2-bit fields.
 *         Each field encodes the original position (0-3) of the corresponding
 *         nonzero element in the input. If fewer than two nonzeros are found in
 *         a 4-element group, the unused field(s) keep the per-group defaults
 *         {2, 3}, so they stay paired with the tail elements a_vec[i*4+2] /
 *         a_vec[i*4+3] that are carried through by default (see below).
 */
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
    // idx holds one 2-bit index per output element (total CompressedSize entries).
    // Each 4-wide input group contributes two output slots whose index fields are
    // initialized to {2, 3} — per-group nibble 0b1110, LSB-first. This matches the
    // nonzero_elems defaults below, which carry a_vec[i*4+2] and a_vec[i*4+3]: when
    // fewer than two nonzeros are found among positions 0..2, the remaining slot
    // still indexes the position whose value it actually carries (in particular a
    // nonzero at position 3 is handled purely by these defaults). For
    // CompressedSize == 4 this reproduces the original constant 0b11101110;
    // initializing every field to 2 would mispair a default-carried a_vec[i*4+3]
    // with index 2. The loop below clears and sets fields as real nonzeros are seen.
    int32_t idx = 0;
    static_for<0, CompressedSize / 2, 1>{}([&](auto g) { idx |= (0b1110 << (4 * g)); });

    static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
        // Defaults: if fewer than two nonzeros are found among positions 0..2, the
        // compressed pair falls back to the last two elements of the group.
        ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
        int32_t non_zero_pos       = 0;
        // Only positions 0..2 are scanned: with at most two nonzeros per group
        // (2:4 sparsity), a nonzero at position 3 is covered by the defaults.
        static_for<0, 3, 1>{}([&](auto j) {
            // The non_zero_pos < 2 guard prevents an out-of-bounds write to
            // nonzero_elems if the input violates the 2:4 sparsity precondition
            // (more than two nonzeros in a group).
            if(non_zero_pos < 2 && a_vec[i * 4 + j] != 0.0f)
            {
                nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
                // Clear the 2-bit field for this output slot and insert j.
                idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
                idx |= j << 2 * (i * 2 + non_zero_pos);
                ++non_zero_pos;
            }
        });
        a_vec[i * 2]     = nonzero_elems[0];
        a_vec[i * 2 + 1] = nonzero_elems[1];
    });
    return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
@@ -41,37 +87,10 @@ struct WarpGemmSmfmacImpl
return WarpGemmAttribute_::get_num_of_access();
}
//----------------------------------------------------------------------------------------------
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
/// elements into lower part of a_vec to half its effective size.
///
/// @param a_vec Vector to be compressed.
///
/// @return Four 2-bit indexes of non-zero elements locations
///
template <typename AVec>
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
{
int32_t idx = 0b11101110;
static_for<0, 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
return compress_a_impl<ADataType, CompressedSize>(a_vec);
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
@@ -84,10 +103,11 @@ struct WarpGemmSmfmacImpl
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -95,8 +115,9 @@ struct WarpGemmSmfmacImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a(a_vec);
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};