[CK_Tile] Refactor amdgcn_mma policy structs (#5272)

## Motivation
The point of this MR is to update the intrinsic layout parameters to make them simpler, clearer, and more flexible. A number of small refactors were also performed to reduce boilerplate and code duplication.

## Technical Details
In CK Tile and old CK, the full set of information available in the
intrinsic wrappers, for WMMA and MFMA combined, would be something like:

```
// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kAMBlock;
static constexpr index_t kBNBlock;

static constexpr index_t kRepeat;
static constexpr index_t kAMLane;
static constexpr index_t kBNLane;
static constexpr index_t kABK0PerLane;
static constexpr index_t kABKLane;
static constexpr index_t kABK1PerLane;

static constexpr index_t kCMLane;
static constexpr index_t kCNLane;
static constexpr index_t kCM0PerLane;
static constexpr index_t kCM1PerLane;

using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
using kABYs2RHsMajor  = sequence<2, 2>;
using kABYs2RHsMinor  = sequence<0, 2>;

using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor  = sequence<1, 1>;
using kCYs2RHsMinor  = sequence<0, 2>;

using kCTPs2RHssMajor = sequence<2, 1>;
using kCTPs2RHssMinor = sequence<1, 0>;
using kCTYs2RHsMajor  = sequence<2, 2>;
using kCTYs2RHsMinor  = sequence<0, 2>;
```
Note that on top of the intrinsic sizes, we have 12 layout parameters. I have reduced this in the new design to:

```
// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kABKPerLane;  // K2 * K0, always the same, even for diff A / B layouts
static constexpr index_t kAKNumAccess; // K2
static constexpr index_t kARepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kBKNumAccess; // K2
static constexpr index_t kBRepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kCMPerLane;   // M2 * M0
static constexpr index_t kCMNumAccess; // M2

// Derived properties
using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;
```

Note that there are now only 7 layout parameters and no more dimensionality orderings. Believe it or not, these 7 parameters are more general than the original 12, and can handle intrinsic and mid-level features that are currently awkward in CK Tile, like dealing with AttrNumAccess, different A / B layouts, more general block-hiding (currently very limited in CK Tile), and future arch features.
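
For concreteness, here is what those parameters look like for the plain dense fp16 16x16x16 MFMA on gfx9 (wave size 64), using the values passed to amdgcn_mma_base further down in this diff; treat this as an illustrative sketch rather than a reference:

```
// Layout parameters for the dense fp16 16x16x16 MFMA on gfx9 (WaveSize = 64).
// Values taken from the amdgcn_mma_base instantiation in this MR; illustrative only.
static constexpr index_t kABKPerLane  = 4; // K2 * K0: four K elements per lane for A and B
static constexpr index_t kAKNumAccess = 1; // K2: no K split needed
static constexpr index_t kARepeat     = 1; // no RDNA3 repetition / block hiding
static constexpr index_t kBKNumAccess = 1; // K2
static constexpr index_t kBRepeat     = 1;
static constexpr index_t kCMPerLane   = 4; // M2 * M0: four M elements per lane in C
static constexpr index_t kCMNumAccess = 1; // M2
```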

Furthermore, the A, B and C vec types are now derived directly from the layout parameters to ensure internal consistency.
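
As a sketch of that derivation (this mirrors the lines in amdgcn_mma_base shown in the diff below), the vector widths follow directly from the fragment sizes, the repeat factors and the wave size:

```
// Per-lane vector widths derived from the layout parameters (as in amdgcn_mma_base).
static_assert((kM * kK * kARepeat) % WaveSize == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
// Example: 16x16x16 fp16 MFMA with WaveSize 64 gives A/B vectors of 4 fp16 and a C vector of 4 fp32.
```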

I added a detailed explanation of the new params in terms of register mappings at the top of amdgcn_mma.hpp.

Other refactorings I did in this MR:

- Make an amdgcn_mma_base struct to drastically reduce code duplication and potential bugs. This should also make auto-generating the amdgcn specializations much easier (see the sketch after this list).
- Simplify the MmaOpTraits significantly by only including those parameters that are not directly gettable from the MmaOp itself. This removes duplicated variables and simplifies higher level code.
- Remove the overloaded "Block" term for intrinsic dimensions and replace it with "Frag". Some spots were already using the term "Frag" for combined intrinsics, in which case I changed that term to "Chunk" instead.
- Remove some tests that had become somewhat pointless (setting variables and then checking their values immediately).
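
As an illustration of the amdgcn_mma_base point above, a specialization now only forwards its parameters to the base and provides the intrinsic call. This is a trimmed sketch modeled on the gfx9 fp16 16x16x16 specialization in this diff (the exec body is elided):

```
// Sketch: shape of a specialization after the refactor (based on the gfx9 fp16 case below).
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget,
                  MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
    : amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u,
                      /*kABKPerLane*/ 4, /*kAKNumAccess*/ 1, /*kARepeat*/ 1,
                      /*kBKNumAccess*/ 1, /*kBRepeat*/ 1,
                      /*kCMPerLane*/ 4, /*kCMNumAccess*/ 1,
                      MfmaOp, MmaOpFamily::DENSE>
{
    // All data types, layout constants and vector types come from the base;
    // the specialization only wraps the hardware builtin.
    CK_TILE_DEVICE static auto
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType;
};
```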

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

View File

@@ -14,6 +14,160 @@
namespace ck_tile::core::arch::mma {
/**---------------------------------------------------
* Meaning of amdgcn_mma layout parameters (general)
* ---------------------------------------------------
*
* The fragment (MmaTile) sizes and layout constants in the amdgcn_mma struct describe the mapping
* between intrinsic input / output matrix elements and vector registers (lane x vector_item space).
* Note that we end up having a mapping for A, B and C separately, although those for A and B are
* usually similar if not identical. All mappings can be described as an unmerge operation on one of
* the matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and
* raw other dim into the Lane and Vector_item dimensions. When considering an unmerge operation on
* a dimension K, we can label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
* of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and
* K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and
* unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a
* K dimension of size 12:
*
* K K2 K1 K0
* 0 0 0 0
* 1 0 0 1
* 2 0 1 0
* 3 0 1 1
* 4 0 2 0
* 5 0 2 1
* 6 1 0 0
* 7 1 0 1
* 8 1 1 0
* 9 1 1 1
* 10 1 2 0
* 11 1 2 1
*
* Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
* second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
*
* If we were to use this unmerge op to describe an A matrix layout in registers, we might have for
* example that L (lane dim) is composed of K1 and M, and V (vector_item dim) is composed of K2 and
* K0. Compactly described, this would be K{3, 2} L{K1M} V{K2K0}, and if the M dimension was 2 we
* would have the following layout (6 lanes, 4 vector items each):
*
* | V0 | V1 | V2 | V3 |
* L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 |
* L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 |
* L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 |
* L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 |
* L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 |
* L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 |
*
* Note that all A matrix elements are now placed in a unique (lane, vector_item). In case a Repeat
* dimension is used, every single matrix element is mapped to multiple (Lane, vector_item)
* locations, usually along the Lane dimension.
*
* Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any
* register mapping (expressed as a tile distribution encoding).
*
* ------------------------------------------
* Individual amdgcn_mma layout parameters
* ------------------------------------------
*
* -- ABKPerLane --
* The number of K dim elements in each lane. Always the same for A and B, even when they have
* different layouts. In terms of unmerge sizes, it's equal to K0 * K2, i.e. the product of the sizes
* of the outermost and innermost dimensions after a double K unmerge.
*
* -- A / B NumAccess --
* These two variables describe the size of the outermost dimension if two unmerge operations are
* required for K (so K2). Alternatively, it can be described as the number of sets into which the
* vector dimension (which houses a number of K indices) is split. We may be able to actually
* remove the A / B NumAccess from the amdgcn struct, but it sort of depends on how load and store
* tile work and whether we want the mid-level code to always have to know about this. There are
* only two reasons for the A / B NumAccess to ever not be 1, and they are different types of
* reasons:
*
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
* with a Num Access value of at least 2.
*
* (load / store manipulation). It seems like the load and store tile functions end up looking for
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
* loaded at a time. Different Num Access values will lead to different load / store behavior, even
* if logically equivalent.
*
* -- A / B Repeat --
* Variable indicating that all matrix values are represented multiple times in the vector
* registers, typically repeating in the lane dimension. This is always equal to the repeat value
* used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value
* here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
*
* -- CMPerLane --
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e.
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
*
* -- CMNumAccess --
* Same as A / B NumAccess but for the M dim (so M2), except that the mid-level code doesn't care
* about this and will not try to request a specific value. It is absolutely needed for logical
* correctness of register mappings, since we cannot perform arbitrary M permutations without
* messing up the A layout.
*/
/**
* @class amdgcn_mma_base
* @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
* all generic parameter derivations and static asserts in one place. Houses all of the
* amdgcn struct types and variables, except for the exec() function.
*/
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveSize_,
index_t kABKPerLane_,
index_t kAKNumAccess_,
index_t kARepeat_,
index_t kBKNumAccess_,
index_t kBRepeat_,
index_t kCMPerLane_,
index_t kCMNumAccess_,
typename OpType_,
MmaOpFamily OpFamily_>
struct amdgcn_mma_base
{
using OpType = OpType_;
static constexpr MmaOpFamily OpFamily = OpFamily_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
// Fragment (MmaTile) sizes, check description above.
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
static constexpr index_t kN = FragN;
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
// Layout constants, check description above.
static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0
static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2
static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding
static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2
static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
// Register types (derived)
static constexpr index_t WaveSize = WaveSize_;
static_assert((kM * kK * kARepeat) % WaveSize == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
};
/**
* @struct Unsupported
* @brief Meta-tag to indicate unsupported amdgcn_mma instance.
@@ -31,23 +185,24 @@ template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;
// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
typename MmaOp::BDataType;
typename MmaOp::CDataType;
typename MmaOp::AVecType;
typename MmaOp::BVecType;
typename MmaOp::CVecType;
// Captures CK-specific layout properties
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kARepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kBKNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
// Static exec function
{
@@ -69,52 +224,40 @@ concept MmaOpI = requires(MmaOp op) {
* @tparam ADataType Datatype of input A
* @tparam BDataType Datatype of input B
* @tparam CDataType Datatype of accumulator
* @tparam BlockM M-dimension of mma block
* @tparam BlockN N-dimension of mma block
* @tparam BlockK K-dimension of mma block
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
* @tparam CtrlFlags Control flags for mma operation
* @tparam CompilerTarget The current compiler target
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
* @tparam Enabler SFINAE enabler
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CtrlFlags,
typename CompilerTarget,
MmaOpFamily OpFamily_,
typename Enabler = void>
struct amdgcn_mma
// clang-format off
// | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1, 1, 1, 1, 1, 1, Unsupported, MmaOpFamily::UNDEFINED>
// clang-format on
{
// The base instance is unsupported because there is no __builtin to wrap.
using OpType = Unsupported;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
// Interface types for A, B, C vectors types
using AVecType = ext_vector_t<ADataType, 1>;
using BVecType = ext_vector_t<BDataType, 1>;
using CVecType = ext_vector_t<CDataType, 1>;
// Layout constants - default to 0
static constexpr index_t kAMBlock = 0;
static constexpr index_t kBNBlock = 0;
static constexpr index_t kAMLane = 0;
static constexpr index_t kBNLane = 0;
static constexpr index_t kABKLane = 0;
static constexpr index_t kABKPerLane = 0;
static constexpr index_t kCMLane = 0;
static constexpr index_t kCNLane = 0;
static constexpr index_t kCM0PerLane = 0;
static constexpr index_t kCM1PerLane = 0;
// This is a default pass-through implementation that doesn't do anything practical.
CK_TILE_DEVICE static CVecType const&
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
// Prints once across all thread blocks and threads.
static __device__ int printed = 0;
if(threadIdx.x == 0 && atomicCAS(&printed, 0, 1) == 0)
{
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
}
ignore(regsA, regsB);
return regsC; // No-op, just return C
}

View File

@@ -25,7 +25,7 @@ namespace ck_tile::core::arch::mma {
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
*
* This specialization implements the MFMA instruction for fp16_t A and B
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
* matrices, and fp32_t accumulator matrix, with 16x16x16 fragment sizes.
*
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
@@ -33,40 +33,12 @@ namespace ck_tile::core::arch::mma {
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_family_gfx9_t<CompilerTarget>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u, 4, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
// Mfma operation type
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<fp32_t, 4>;
// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
@@ -84,7 +56,7 @@ struct amdgcn_mma<fp16_t,
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
*
* This specialization implements the MFMA instruction for fp16_t A and B
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
* matrices, and fp32_t accumulator matrix, with 16x16x32 fragment sizes.
*
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
@@ -92,39 +64,12 @@ struct amdgcn_mma<fp16_t,
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Packed register types
using AVecType = ext_vector_t<fp16_t, 8>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<fp32_t, 4>;
// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 8;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{

View File

@@ -18,28 +18,27 @@ namespace ck_tile::core::arch::mma {
* @class MfmaDefaultSelector
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
* This implements the K dimension search strategy to find the largest supported MFMA
* instruction for the given M/N block sizes and datatypes.
* If no supported instruction is found, falls back to an unsupported pass-through
implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Block M dimension size
* @tparam BlockN Block N dimension size
* @tparam BlockKTest Current Block K dimension size to test
* instruction for the given M/N WaveTile sizes and datatypes.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM WaveTile M dimension size
* @tparam WaveTileN WaveTile N dimension size
* @tparam WaveTileKTest Current WaveTile K dimension size to test
* @tparam CompilerTarget The compiler target
* @note Here we assume that BlockKTest is always a power-of-two integer.
* The search strategy starts from a maximum BlockKTest size down to 1u by halving
* @note Here we assume that WaveTileKTest is always a power-of-two integer.
* The search strategy starts from a maximum WaveTileKTest size down to 1u by halving
* each time.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileKTest,
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest))
struct MfmaDefaultSelector
{
private:
@@ -48,26 +47,25 @@ struct MfmaDefaultSelector
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
CompilerTarget,
MmaOpFamily::DENSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
// and fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
// Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
// WaveTileK=0u and fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
CandidateOp,
typename MfmaDefaultSelector<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest / 2u,
WaveTileM,
WaveTileN,
WaveTileKTest / 2u,
CompilerTarget>::SelectedOp>;
};
@@ -75,28 +73,34 @@ struct MfmaDefaultSelector
* @struct MfmaDefaultSelector
* @brief Implements the base case for the default MFMA selector when no supported instruction is
* found.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Block M dimension size
* @tparam BlockN Block N dimension size
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM WaveTile M dimension size
* @tparam WaveTileN WaveTile N dimension size
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t WaveTileM,
uint32_t WaveTileN,
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
struct MfmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
1u,
CompilerTarget>
{
// Default unsupported pass-through if no instruction is found
using SelectedOp =
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
WaveTileM,
WaveTileN,
1u,
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
CompilerTarget,
@@ -106,32 +110,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
/**
* @struct MmaDefaultSelector
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
* This implements the M/N block size search strategy to find the largest supported MFMA
* This implements the M/N WaveTile size search strategy to find the largest supported MFMA
* instruction for the given datatypes.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Size of the M dimension of the fragment to decompose
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
@@ -163,27 +167,20 @@ struct MmaDefaultSelector<ADataType,
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
SelectedOp;
// Traits for each candidate
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the MFMA shape
// Check if each candidate is supported for the given WaveTile sizes
// For this case, we require the WaveTile sizes to be multiples of the MFMA shape
static constexpr bool IsSupported4x4 =
CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
(FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
(FragM % CandidateTraits16x16::BlockM == 0u) &&
(FragN % CandidateTraits16x16::BlockN == 0u) &&
(FragK % CandidateTraits16x16::BlockK == 0u);
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
(FragM % CandidateTraits32x32::BlockM == 0u) &&
(FragN % CandidateTraits32x32::BlockN == 0u) &&
(FragK % CandidateTraits32x32::BlockK == 0u);
MmaOpTraits<CandidateOp4x4>::IsSupported && (WaveTileM % CandidateOp4x4::kM == 0u) &&
(WaveTileN % CandidateOp4x4::kN == 0u) && (WaveTileK % CandidateOp4x4::kK == 0u);
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
static constexpr bool IsSupported32x32 =
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
public:
// Select the largest supported MFMA operation for the given fragment shape
// Select the largest supported MFMA operation for the given WaveTile shape
using SelectedOp = std::conditional_t<
IsSupported32x32,
CandidateOp32x32,

View File

@@ -6,7 +6,6 @@
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
@@ -19,44 +18,42 @@ namespace ck_tile::core::arch::mma {
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major block order
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major block order
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
* (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
* fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
* (C: FragM x FragN).
* @tparam ADataType Data type of input fragment A
* @tparam BDataType Data type of input fragment B
* @tparam CDataType Data type of input/output fragment C (accumulator)
* @tparam FragM Mma fragment M dimension
* @tparam FragN Mma fragment K dimension
* @tparam FragK Mma fragment M dimension
* @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
* order)
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
* wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output fragments
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output fragment. This is a powerful example of how to build a flexible and reusable mma driver
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
@@ -67,39 +64,37 @@ template <typename ADataType,
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma
{
using FragWiseMmaOp = MmaOp;
using BlockWiseMmaOp = MmaOp;
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Block dimensions
constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Block counts for decomposition
constexpr static uint32_t BlocksM = FragM / BlockM;
constexpr static uint32_t BlocksN = FragN / BlockN;
constexpr static uint32_t BlocksK = FragK / BlockK;
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
// Vector types for packed registers in each fragment
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Vector types for packed registers in each block
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
// Buffer types for fragments
using ABufferType = AVecType[BlocksM][BlocksK];
using BBufferType = BVecType[BlocksN][BlocksK];
using CBufferType = CVecType[BlocksM][BlocksN];
// Buffer types for WaveTiles
using ABufferType = AVecType[FragsM][FragsK];
using BBufferType = BVecType[FragsN][FragsK];
using CBufferType = CVecType[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
@@ -108,20 +103,20 @@ struct WaveWiseMma
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
static_assert(WaveTileM >= FragM, "WaveTileM must be at least as large as FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be at least as large as FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be at least as large as FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
private:
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input fragments to the native vector types
// required by the BlockWiseMma operation for iteration
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT const&>(inputBuffer);
}
@@ -130,16 +125,16 @@ struct WaveWiseMma
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input fragments to the native vector types
// required by the BlockWiseMma operation for iteration
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT&>(inputBuffer);
}
/*! @brief Execute Mma in col-major accumulation order.
* @tparam VecTA The input fragment A vector type
* @tparam VecTB The input fragment B vector type
* @tparam VecTC The input/output fragment C vector type
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
@@ -153,35 +148,35 @@ struct WaveWiseMma
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Col-major" accumulation over the M-dimension blocks first.
// Pseudo code here, but we would basically iterate over the blocks in col-major order
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output fragment format
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input fragment A vector type
* @tparam VecTB The input fragment B vector type
* @tparam VecTC The input/output fragment C vector type
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input fragments,
// First, we apply the necessary transforms to the input WaveTiles,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
@@ -189,32 +184,32 @@ struct WaveWiseMma
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Row-major" accumulation over the N-dimension blocks first.
// Pseudo code here, but we would basically iterate over the blocks in row-major order.
// We also have to ensure that the incoming vector fragments are converted to native vector
// types before passing to the BlockWiseMma exec function.
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
// types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output fragment format
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
public:
/*! @brief Forward to Mma operation with specified accumulation order.
* @tparam VecTA The input fragment A vector type
* @tparam VecTB The input fragment B vector type
* @tparam VecTC The input/output fragment C vector type
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)

View File

@@ -9,32 +9,31 @@ namespace ck_tile::core::arch::mma {
/**
* @class MmaDefaultSelector
* @brief Implements a default mma selector strategy for the current target architecture.
* This is simply intended as a default selection strategy for mma instruction operations.
* Given the particular datatypes and Fragment dimensions, the selector will attempt to
* select the instruction with the largest K dimension that is supported on the current target
* architecture.
* @brief Implements a default mma selector strategy for the current target architecture. This is
* simply intended as a default selection strategy for mma instruction operations. Given the
* particular datatypes and WaveTile dimensions, the selector will attempt to select the instruction
* with the largest K dimension that is supported on the current target architecture.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Fragment M dimension
* @tparam FragN Fragment N dimension
* @tparam FragK Fragment K dimension
* @tparam WaveTileM WaveTile M dimension
* @tparam WaveTileN WaveTile N dimension
* @tparam WaveTileK WaveTile K dimension
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam Enable SFINAE enabler
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
* correspond to the size of the individual MMA instructions being used to compute the overall in
* block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
* equal to the Block sizes.
* @note Here we distinguish WaveTile MNK sizes from the Fragment MNK sizes used in the actual MMA
* operation. WaveTile sizes correspond to the overall tile size being computed, while Fragment
* sizes correspond to the size of the individual MMA instructions being used to compute the
* overall tile fragment-wise. The WaveTile sizes must be multiples of the Fragment sizes and in
* general larger than or equal to the Fragment sizes.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily,
typename Enable = void>
@@ -46,9 +45,9 @@ struct MmaDefaultSelector
using SelectedOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>;

View File

@@ -44,122 +44,64 @@ struct is_mma_op_supported<MmaOp,
template <typename MmaOp>
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
/**
* @class MmaOpParams
* @brief Reflects the template parameters of a given MmaOp
* @tparam MmaOp The matrix multiply-accumulate operation type to check
*/
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
struct MmaOpParams;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
template <typename T>
struct MmaOpTraits;
/**
* @concept MmaOpParamsI
* @brief Expresses the required members for each MmaOp
*/
template <typename MmaOpParams>
concept MmaOpParamsI = requires(MmaOpParams op) {
// Capture template parameters
typename MmaOpParams::ADataType;
typename MmaOpParams::BDataType;
typename MmaOpParams::CDataType;
typename MmaOpParams::CtrlFlags;
{ MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
{ MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
{ MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @struct MmaOpParams
* @brief Reflects the template parameters of a given MmaOp
* @struct MmaOpTraits
* @brief Gives additional traits and unexposed template parameters of a given MmaOp
* @tparam ADataType_ Data type of matrix A
* @tparam BDataType_ Data type of matrix B
* @tparam CDataType_ Data type of the accumulator
* @tparam BlockM_ Size of the M dimension
* @tparam BlockN_ Size of the N dimension
* @tparam BlockK_ Size of the K dimension
* @tparam FragM_ Size of the M dimension
* @tparam FragN_ Size of the N dimension
* @tparam FragK_ Size of the K dimension
* @tparam CtrlFlags_ Control flags for the MMA operation
* @tparam CompilerTarget_ The compiler target
*/
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
uint32_t BlockM_,
uint32_t BlockN_,
uint32_t BlockK_,
uint32_t FragM_,
uint32_t FragN_,
uint32_t FragK_,
typename CtrlFlags_,
typename CompilerTarget_,
MmaOpFamily OpFamily_>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
struct MmaOpParams<amdgcn_mma<ADataType_,
struct MmaOpTraits<amdgcn_mma<ADataType_,
BDataType_,
CDataType_,
BlockM_,
BlockN_,
BlockK_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>>
{
// Capture incoming template parameters
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
static constexpr uint32_t BlockM = BlockM_;
static constexpr uint32_t BlockN = BlockN_;
static constexpr uint32_t BlockK = BlockK_;
using CtrlFlags = CtrlFlags_;
using CompilerTarget = CompilerTarget_;
static constexpr auto MmaOpFamily = OpFamily_;
using MmaOp = amdgcn_mma<ADataType_,
BDataType_,
CDataType_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>;
// Capture incoming template parameters not already in amdgcn
using CtrlFlags = CtrlFlags_;
using CompilerTarget = CompilerTarget_;
// TODO: c++20 static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
};
/**
* @class MmaOpTraits
* @brief Reflects the template parameters and static members of a given MmaOp.
* @tparam MmaOp The matrix multiply-accumulate operation
*/
template <typename MmaOp>
// TODO: c++20 template <MmaOpI MmaOp>
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
struct MmaOpTraits : public MmaOpParams<MmaOp>
{
// Capture internal MmaOp static members
using OpType = typename MmaOp::OpType;
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
// Capture layout parameters
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
static constexpr index_t kAMLane = MmaOp::kAMLane;
static constexpr index_t kBNLane = MmaOp::kBNLane;
static constexpr index_t kABKLane = MmaOp::kABKLane;
static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
static constexpr index_t kCMLane = MmaOp::kCMLane;
static constexpr index_t kCNLane = MmaOp::kCNLane;
static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
// Additional traits to identify the type of MmaOp at compile time
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
constexpr static bool IsDense = OpFamily_ == MmaOpFamily::DENSE;
constexpr static bool IsSparse = OpFamily_ == MmaOpFamily::SPARSE;
constexpr static bool IsScale = OpFamily_ == MmaOpFamily::SCALE;
constexpr static bool IsSupported =
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
is_mma_op_supported_v<MmaOp> && OpFamily_ != MmaOpFamily::UNDEFINED;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -14,23 +14,24 @@ namespace ck_tile::core::arch::mma {
/**
* @class SparseMfmaDefaultSelector
* @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam BlockKTest Size of the K dimension
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam WaveTileKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
// is_power_of_two_integer(WaveTileKTest))
struct SparseMfmaDefaultSelector
{
private:
@@ -38,26 +39,24 @@ struct SparseMfmaDefaultSelector
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultSparseMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
CandidateOp,
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
@@ -67,21 +66,21 @@ struct SparseMfmaDefaultSelector
* @struct MmaDefaultSelector
* @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Size of the M dimension of the fragment to decompose
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
@@ -89,9 +88,9 @@ template <typename ADataType,
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
@@ -126,23 +125,17 @@ struct MmaDefaultSelector<ADataType,
1u,
CompilerTarget>::SelectedOp;
// Traits for each candidate
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the MFMA shape
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
(FragM % CandidateTraits16x16::BlockM == 0u) &&
(FragN % CandidateTraits16x16::BlockN == 0u) &&
(FragK % CandidateTraits16x16::BlockK == 0u);
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
(FragM % CandidateTraits32x32::BlockM == 0u) &&
(FragN % CandidateTraits32x32::BlockN == 0u) &&
(FragK % CandidateTraits32x32::BlockK == 0u);
// Check if each candidate is supported for the given WaveTile sizes
// For this case, we require the WaveTile sizes to be multiples of the MFMA shape
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
static constexpr bool IsSupported32x32 =
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
public:
// Select the largest supported MFMA operation for the given fragment shape
// Select the largest supported MFMA operation for the given WaveTile shape
using SelectedOp =
std::conditional_t<IsSupported32x32,
CandidateOp32x32,

View File

@@ -16,7 +16,7 @@ namespace ck_tile::core::arch::mma {
* @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
*
* This specialization implements the SMFMA instruction for fp16_t A and B
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes.
*
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
* @tparam CompilerTarget Current compiler target
@@ -24,53 +24,25 @@ namespace ck_tile::core::arch::mma {
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<
fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE,
std::enable_if_t<is_any_value_of(
CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
static constexpr index_t ABVecN = 8;
using AVecType = ext_vector_t<fp16_t, ABVecN>;
using BVecType = ext_vector_t<fp16_t, ABVecN>;
using CVecType = ext_vector_t<fp32_t, 4>;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
static constexpr index_t kCompressionRatio = 2;
CK_TILE_DEVICE static auto
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
static constexpr index_t kCompressionRatio = 2;
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static_assert(CompressedSize == 4);
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
// and evaluate changing this to a transform at a higher level.
// aVec not being const can cause problems when running multiple intrinsics.
const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};

View File

@@ -13,23 +13,24 @@ namespace ck_tile::core::arch::mma {
/**
* @class SparseWmmaDefaultSelector
* @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam BlockKTest Size of the K dimension
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam WaveTileKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) &&
// is_power_of_two_integer(WaveTileKTest))
struct SparseWmmaDefaultSelector
{
private:
@@ -37,9 +38,9 @@ struct SparseWmmaDefaultSelector
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultSparseWmmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>;
@@ -54,9 +55,9 @@ struct SparseWmmaDefaultSelector
amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
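
The `DefaultOp` alias above deliberately instantiates `amdgcn_mma` with `void` flags, an empty `amdgcn_target<>`, and `MmaOpFamily::UNDEFINED`, so it lands on the primary template and its traits report it as unsupported. As a generic sketch of that pass-through pattern (not the actual ck_tile definitions), the traits side can be as simple as:

```
// Generic "unsupported fallback" sketch, not the ck_tile implementation:
// the primary traits template says "no backend"; a detection-based
// specialization says "a real exec() exists, so this op is usable".
#include <type_traits>

template <typename Op, typename = void>
struct op_traits
{
    static constexpr bool IsSupported = false; // primary template: pass-through only
};

template <typename Op>
struct op_traits<Op, std::void_t<decltype(&Op::exec)>>
{
    static constexpr bool IsSupported = true; // a concrete specialization was selected
};
```
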
@@ -66,21 +67,21 @@ struct SparseWmmaDefaultSelector
* @struct MmaDefaultSelector
* @brief Implements the RDNA default MMA selector strategy for sparse WMMA.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Size of the M dimension of the fragment to decompose
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
@@ -88,9 +89,9 @@ template <typename ADataType,
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
@@ -116,18 +117,14 @@ struct MmaDefaultSelector<ADataType,
1u,
CompilerTarget>::SelectedOp;
// Traits for each candidate
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the WMMA shape
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
(FragM % CandidateTraits16x16::BlockM == 0u) &&
(FragN % CandidateTraits16x16::BlockN == 0u) &&
(FragK % CandidateTraits16x16::BlockK == 0u);
// Check if each candidate is supported for the given WaveTile sizes
// For this case, we require the WaveTile sizes to be multiples of the WMMA shape
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
public:
// Select the largest supported WMMA operation for the given fragment shape
// Select the largest supported WMMA operation for the given WaveTile shape
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};
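
For context, the dense and sparse selectors are consumed the same way: a higher-level policy asks for `SelectedOp` and then reads the fragment sizes straight off the op (kM/kN/kK), since the slimmed-down MmaOpTraits no longer duplicates them. A hypothetical call site, with made-up tile sizes and a `CompilerTarget` assumed to be in scope, might look like:

```
// Hypothetical caller, illustrative only.
using Mma = typename MmaDefaultSelector<fp16_t, fp16_t, fp32_t,
                                        /*WaveTileM=*/32u, /*WaveTileN=*/32u, /*WaveTileK=*/64u,
                                        CompilerTarget, MmaOpFamily::SPARSE>::SelectedOp;

static_assert(MmaOpTraits<Mma>::IsSupported, "no sparse WMMA for this target / shape");

// Repeat counts are derived directly from the op's fragment sizes.
static constexpr auto kMRepeat = 32u / Mma::kM;
static constexpr auto kNRepeat = 32u / Mma::kN;
static constexpr auto kKRepeat = 64u / Mma::kK;
```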

@@ -15,51 +15,24 @@ namespace ck_tile::core::arch::mma {
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE,
enable_if_target_family_gfx12_t<CompilerTarget>>
// clang-format off
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
using OpType = WmmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
static constexpr index_t ABVecN = 16;
using AVecType = ext_vector_t<fp16_t, ABVecN>;
using BVecType = ext_vector_t<fp16_t, ABVecN>;
using CVecType = ext_vector_t<fp32_t, 8>;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
static constexpr index_t kCompressionRatio = 2;
CK_TILE_DEVICE static auto
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
static constexpr index_t kCompressionRatio = 2;
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static_assert(CompressedSize == 8);
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
// and evaluate changing this to a transform at a higher level.
// aVec not being const can cause problems when running multiple intrinsics.
const int32_t idx = ::ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const AVecCompressed a_vec_pruned = {
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
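
On the compression TODO: as I read it, `kCompressionRatio = 2` encodes the standard 4:2 structured-sparsity scheme these instructions use, where at most two of every four A elements along K are non-zero; the intrinsic then consumes the two surviving values plus a per-group index recording their original positions. A minimal host-side illustration of that idea (not the actual `ck_tile::compress_a_impl`, whose index packing has to match the hardware metadata layout) is:

```
// Illustrative 2-of-4 compression: keep 2 of every 4 elements and pack their
// positions as 2-bit fields. Not the real ck_tile::compress_a_impl.
#include <array>
#include <cstdint>
#include <utility>

template <typename T, std::size_t N>
std::pair<std::array<T, N / 2>, uint32_t> compress_2of4(const std::array<T, N>& a)
{
    static_assert(N % 4 == 0, "K elements must come in groups of 4");
    std::array<T, N / 2> kept{};
    uint32_t idx = 0;
    std::size_t out = 0, bit = 0;
    for(std::size_t g = 0; g < N; g += 4)
    {
        std::size_t found = 0;
        for(std::size_t k = 0; k < 4 && found < 2; ++k)
        {
            if(a[g + k] != T{0})
            {
                kept[out++] = a[g + k];
                idx |= static_cast<uint32_t>(k) << bit; // position within the group of 4
                bit += 2;
                ++found;
            }
        }
        for(; found < 2; ++found) // fewer than two non-zeros: pad with zero, position 0
        {
            kept[out++] = T{0};
            bit += 2;
        }
    }
    return {kept, idx};
}
```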

@@ -70,38 +70,12 @@ struct DefaultWmmaCtrlFlags
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
// Wmma operation type
using OpType = WmmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types (duplicated input / b32 accum)
using AVecType = ext_vector_t<fp16_t, 16>;
using BVecType = ext_vector_t<fp16_t, 16>;
using CVecType = ext_vector_t<fp32_t, 8>;
// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 8;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 2;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 1;
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{

@@ -30,38 +30,12 @@ namespace ck_tile::core::arch::mma {
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_family_gfx12_t<CompilerTarget>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
// Wmma operation type
using OpType = WmmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp16_t, 8>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<fp32_t, 8>;
// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 8;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 2;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 1;
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{

@@ -14,24 +14,24 @@ namespace ck_tile::core::arch::mma {
* @class WmmaDefaultSelector
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
* This implements the K dimension search strategy to find the largest supported WMMA
* instruction for the given M/N block sizes and datatypes.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam BlockKTest Size of the K dimension
* instruction for the given M/N WaveTile sizes and datatypes.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam WaveTileKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockKTest,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest))
struct WmmaDefaultSelector
{
private:
@@ -42,27 +42,25 @@ struct WmmaDefaultSelector
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest,
WaveTileM,
WaveTileN,
WaveTileKTest,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>;
using CandidateTraits = MmaOpTraits<CandidateOp>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
// and fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
// Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
// WaveTileK=0u and fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
CandidateOp,
typename WmmaDefaultSelector<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockKTest / 2u,
WaveTileM,
WaveTileN,
WaveTileKTest / 2u,
CompilerTarget>::SelectedOp>;
};
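
The recursion above is the whole K-dimension search strategy: try the requested WaveTileK, and if no backend specialization exists, halve K and try again until the K=1 base case (next hunk) supplies the final fallback. Stripped of the ck_tile types, the pattern reduces to an ordinary compile-time recursion; a self-contained model (with illustrative `candidate`/`unsupported_op` types, not ck_tile's) looks like:

```
// Simplified model of the K-halving search; candidate<K> stands in for amdgcn_mma.
#include <type_traits>

template <unsigned K>
struct candidate { static constexpr bool supported = (K == 16u || K == 32u); };
struct unsupported_op {};

template <unsigned K>
struct k_search
{
    // Pick candidate<K> if supported, otherwise recurse with K / 2.
    using type = std::conditional_t<candidate<K>::supported,
                                    candidate<K>,
                                    typename k_search<K / 2u>::type>;
};

template <>
struct k_search<1u> // base case: nothing matched, fall back to the pass-through op
{
    using type = unsupported_op;
};

static_assert(std::is_same_v<k_search<64u>::type, candidate<32u>>);
static_assert(std::is_same_v<k_search<8u>::type, unsupported_op>);
```
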
@@ -72,21 +70,27 @@ struct WmmaDefaultSelector
* This implements the K dimension == 1, which is the base case for the recursive K dimension
* search. If no supported instruction is found, falls back to an unsupported pass-through
* implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam BlockM Size of the M dimension
* @tparam BlockN Size of the N dimension
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t WaveTileM,
uint32_t WaveTileN,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
struct WmmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
1u,
CompilerTarget>
{
// By default, let's assume no special flags for WMMA
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
@@ -95,8 +99,8 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
using SelectedOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
WaveTileM,
WaveTileN,
1u,
CtrlFlags,
CompilerTarget,
@@ -106,24 +110,24 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
/**
* @struct MmaDefaultSelector
* @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
* This implements the M/N block size search strategy to find the largest supported WMMA
* This implements the M/N WaveTile size search strategy to find the largest supported WMMA
* instruction for the given datatypes.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Size of the M dimension of the fragment to decompose
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
@@ -131,9 +135,9 @@ template <typename ADataType,
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
@@ -155,18 +159,14 @@ struct MmaDefaultSelector<ADataType,
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
SelectedOp;
// Traits for each candidate
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the WMMA shape
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
(FragM % CandidateTraits16x16::BlockM == 0u) &&
(FragN % CandidateTraits16x16::BlockN == 0u) &&
(FragK % CandidateTraits16x16::BlockK == 0u);
// Check if each candidate is supported for the given WaveTile sizes
// For this case, we require the WaveTile sizes to be multiples of the WMMA shape
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
public:
// Select the largest supported WMMA operation for the given fragment shape
// Select the largest supported WMMA operation for the given WaveTile shape
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};