mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_Tile] Refactor amdgcn_mma policy structs (#5272)
## Motivation The point of this MR is to update the intrinsic layout parameters to simplify them and make them more clear and flexible. Also, a number of simple refactors were performed to reduce boilerplate and code duplication. ## Technical Details In CK Tile and old CK, the full set of information available in the intrinsic wrappers, for WMMA and MFMA combined, would be something like: ``` // Basic info using ADataType = void; using BDataType = void; using CDataType = void; using AVecType = ext_vector_t<ADataType, 0>; using BVecType = ext_vector_t<BDataType, 0>; using CVecType = ext_vector_t<CDataType, 0>; // Fragment sizes static constexpr index_t kM; static constexpr index_t kN; static constexpr index_t kK; // Layout parameters static constexpr index_t kAMBlock; static constexpr index_t kBNBlock; static constexpr index_t kRepeat; static constexpr index_t kAMLane; static constexpr index_t kBNLane; static constexpr index_t kABK0PerLane; static constexpr index_t kABKLane; static constexpr index_t kABK1PerLane; static constexpr index_t kCMLane; static constexpr index_t kCNLane; static constexpr index_t kCM0PerLane; static constexpr index_t kCM1PerLane; using kABPs2RHssMajor = sequence<2, 1>; using kABPs2RHssMinor = sequence<1, 0>; using kABYs2RHsMajor = sequence<2, 2>; using kABYs2RHsMinor = sequence<0, 2>; using kCPs2RHssMajor = sequence<1, 2>; using kCPs2RHssMinor = sequence<1, 0>; using kCYs2RHsMajor = sequence<1, 1>; using kCYs2RHsMinor = sequence<0, 2>; using kCTPs2RHssMajor = sequence<2, 1>; using kCTPs2RHssMinor = sequence<1, 0>; using kCTYs2RHsMajor = sequence<2, 2>; using kCTYs2RHsMinor = sequence<0, 2>; ``` Note that on top of the intrinsic sizes, we have 12 layout parameters. 
I have reduced this in the new design to: ``` // Basic info using ADataType = void; using BDataType = void; using CDataType = void; // Fragment sizes static constexpr index_t kM; static constexpr index_t kN; static constexpr index_t kK; // Layout parameters static constexpr index_t kABKPerLane; // K2 * K0, Always the same, even for diff A / B layouts static constexpr index_t kAKNumAccess; // K2 static constexpr index_t kARepeat; // Used for RDNA3 repeated inputs and CDNA block hiding. static constexpr index_t kBKNumAccess; // K2 static constexpr index_t kBRepeat; // Used for RDNA3 repeated inputs and CDNA block hiding. static constexpr index_t kCMPerLane; // M2 * M0 static constexpr index_t kCMNumAccess; // M2 // Derived properties using AVecType = ext_vector_t<ADataType, 0>; using BVecType = ext_vector_t<BDataType, 0>; using CVecType = ext_vector_t<CDataType, 0>; ``` Note that there are now only 7 layout parameters and no more dimensionality orderings. Believe it or not these 7 parameters are more general than the original 12, and can handle intrinsic and mid-level features that are currently awkward in CK Tile, like dealing with AttrNumAccess, different A / B layouts, more general block-hiding (currently very limited in CK tile), and future arch features. Furthermore, the A, B and C vec types are now derived directly from the layout parameters to ensure internal consistency. I added a detailed explanation of the new params in terms of register mappings at the top of amdgcn_mma.hpp. Other refactorings I did in this MR: - Make an amdgcn_mma_base struct to drastically reduce code duplication and potential bugs. Should also make auto-generating the amd_gcn specializations much easier. - Simplify the MmaOpTraits significantly by only including those parameters that are not directly gettable from the MmaOp itself. This removes duplicated variables and simplifies higher level code. - Remove overloaded "Block" term for intrinsic dimensions, and replace by "Frag" instead. 
Some spots were already using the term "Frag" for combined intrinsics, in which case I changed that term to "Chunk" instead. - Remove some tests that had become somewhat pointless (setting variables and then checking their values immediately). - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
GitHub
parent
0a3229ea22
commit
e8f9bb0c19
@@ -14,6 +14,160 @@
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**---------------------------------------------------
|
||||
* Meaning of amdgcn_mma layout parameters (general)
|
||||
* ---------------------------------------------------
|
||||
*
|
||||
* The fragment (MmaTile) sizes and layout constants in the amdgcn_mma struct describe the mapping
|
||||
* between intrinsic input / output matrix elements and vector registers (lane x vector_item space).
|
||||
* Note that we end up having a mapping for A, B and C separately, although those for A and B are
|
||||
* usually similar if not identical. All mappings can be described as an unmerge operation on one of
|
||||
* the matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and
|
||||
* raw other dim into the Lane and Vector_item dimensions. When considering an unmerge operation on
|
||||
* a dimension K, we can label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
|
||||
* of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and
|
||||
* K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and
|
||||
* unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a
|
||||
* K dimension of size 12:
|
||||
*
|
||||
* K K2 K1 K0
|
||||
* 0 0 0 0
|
||||
* 1 0 0 1
|
||||
* 2 0 1 0
|
||||
* 3 0 1 1
|
||||
* 4 0 2 0
|
||||
* 5 0 2 1
|
||||
* 6 1 0 0
|
||||
* 7 1 0 1
|
||||
* 8 1 1 0
|
||||
* 9 1 1 1
|
||||
* 10 1 2 0
|
||||
* 11 1 2 1
|
||||
*
|
||||
* Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
|
||||
* second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
|
||||
*
|
||||
* If we were to use this unmerge op to describe an A matrix layout in registers, we might have for
|
||||
* example that L (lane dim) is composed of K1 and M, and V (vector_item dim) is composed of K2 and
|
||||
* K0. Compactly described, this would be K{3, 2} L{K1M} V{K2K0}, and if the M dimension was 2 we
|
||||
* would have the following layout (6 lanes, 4 vector items each):
|
||||
*
|
||||
* | V0 | V1 | V2 | V3 |
|
||||
* L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 |
|
||||
* L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 |
|
||||
* L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 |
|
||||
* L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 |
|
||||
* L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 |
|
||||
* L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 |
|
||||
*
|
||||
* Note that all A matrix elements are now placed in a unique (lane, vector_item). In case a Repeat
|
||||
* dimension is used, every single matrix element is mapped to multiple (Lane, vector_item)
|
||||
* locations, usually along the Lane dimension.
|
||||
*
|
||||
* Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any
|
||||
* register mapping (expressed as a tile distribution encoding).
|
||||
*
|
||||
* ------------------------------------------
|
||||
* Individual amdgcn_mma layout parameters
|
||||
* ------------------------------------------
|
||||
*
|
||||
* -- ABKPerLane --
|
||||
* The number of K dim elements in each lane. Always the same for A and B, even when they have
|
||||
* different layouts. In terms of unmerge sizes, it's equal to K0 * K2, i.e the product of the sizes
|
||||
* of the outermost and innermost dimensions after a double K unmerge.
|
||||
*
|
||||
* -- A / B NumAccess --
|
||||
* These two variables describe the size of the outermost dimension if two unmerge operations are
|
||||
* required for K (so K2). Alternatively it can be described as the number of sets the vector
|
||||
* dimension, which houses a number of K indices, is split up into. We may be able to actually
|
||||
* remove the A / B NumAccess from the amdgcn struct, but it sort of depends on how load and store
|
||||
* tile work and whether we want the mid-level code to always have to know about this. There are
|
||||
* only two reasons for the A / B NumAccess to ever not be 1, and they are different types of
|
||||
* reasons:
|
||||
*
|
||||
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
|
||||
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
|
||||
* with a Num Access value of at least 2.
|
||||
*
|
||||
* (load / store manipulation). It seems like the load and store tile functions end up looking for
|
||||
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
|
||||
* loaded at a time. Different Num Access values will lead to different load / store behavior, even
|
||||
* if logically equivalent.
|
||||
*
|
||||
* -- A / B Repeat --
|
||||
* Variable indicating that all matrix values are represented multiple times in the vector
|
||||
* registers, typically repeating in the lane dimension. This is always equal to the repeat value
|
||||
* used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value
|
||||
* here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
|
||||
*
|
||||
* -- CMPerLane --
|
||||
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e
|
||||
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
|
||||
*
|
||||
* -- CNumAccess --
|
||||
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
|
||||
* and will not try to request a specific value. Absolutely needed for logical correctness of
|
||||
* register mappings since we can not perform arbitrary M permutations without messing up the A
|
||||
* layout.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @class amdgcn_mma_base
|
||||
* @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
|
||||
* all generic parameter derivations and static asserts in one place. Houses all of the
|
||||
* amdgcn struct types and variables, except for the exec() function.
|
||||
*/
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveSize_,
|
||||
index_t kABKPerLane_,
|
||||
index_t kAKNumAccess_,
|
||||
index_t kARepeat_,
|
||||
index_t kBKNumAccess_,
|
||||
index_t kBRepeat_,
|
||||
index_t kCMPerLane_,
|
||||
index_t kCMNumAccess_,
|
||||
typename OpType_,
|
||||
MmaOpFamily OpFamily_>
|
||||
struct amdgcn_mma_base
|
||||
{
|
||||
using OpType = OpType_;
|
||||
static constexpr MmaOpFamily OpFamily = OpFamily_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
|
||||
// Fragment (MmaTile) sizes, check description above.
|
||||
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
|
||||
static constexpr index_t kN = FragN;
|
||||
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
|
||||
|
||||
// Layout constants, check description above.
|
||||
static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0
|
||||
static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2
|
||||
static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding
|
||||
static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2
|
||||
static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding
|
||||
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
|
||||
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
|
||||
|
||||
// Register types (derived)
|
||||
static constexpr index_t WaveSize = WaveSize_;
|
||||
static_assert((kM * kK * kARepeat) % WaveSize == 0);
|
||||
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
|
||||
static_assert((kM * kN) % WaveSize == 0);
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
|
||||
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
|
||||
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct Unsupported
|
||||
* @brief Meta-tag to indicate unsupported amdgcn_mma instance.
|
||||
@@ -31,23 +185,24 @@ template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
typename MmaOp::OpType;
|
||||
typename MmaOp::OpFamily;
|
||||
|
||||
// Captures types for inputs / outputs to mma function
|
||||
typename MmaOp::ADataType;
|
||||
typename MmaOp::BDataType;
|
||||
typename MmaOp::CDataType;
|
||||
typename MmaOp::AVecType;
|
||||
typename MmaOp::BVecType;
|
||||
typename MmaOp::CVecType;
|
||||
|
||||
// Captures CK-specific layout properties
|
||||
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kARepeat } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBKNumAccess } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
|
||||
|
||||
// Static exec function
|
||||
{
|
||||
@@ -69,52 +224,40 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
* @tparam ADataType Datatype of input A
|
||||
* @tparam BDataType Datatype of input B
|
||||
* @tparam CDataType Datatype of accumulator
|
||||
* @tparam BlockM M-dimension of mma block
|
||||
* @tparam BlockN N-dimension of mma block
|
||||
* @tparam BlockK K-dimension of mma block
|
||||
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam CtrlFlags Control flags for mma operation
|
||||
* @tparam CompilerTarget The current compiler target
|
||||
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
|
||||
* @tparam Enabler SFINAE enabler
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
struct amdgcn_mma
|
||||
// clang-format off
|
||||
// | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1, 1, 1, 1, 1, 1, Unsupported, MmaOpFamily::UNDEFINED>
|
||||
// clang-format on
|
||||
{
|
||||
// The base instance is unsupported because there is no __builtin to wrap.
|
||||
using OpType = Unsupported;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
|
||||
|
||||
// Interface types for A, B, C vectors types
|
||||
using AVecType = ext_vector_t<ADataType, 1>;
|
||||
using BVecType = ext_vector_t<BDataType, 1>;
|
||||
using CVecType = ext_vector_t<CDataType, 1>;
|
||||
|
||||
// Layout constants - default to 0
|
||||
static constexpr index_t kAMBlock = 0;
|
||||
static constexpr index_t kBNBlock = 0;
|
||||
|
||||
static constexpr index_t kAMLane = 0;
|
||||
static constexpr index_t kBNLane = 0;
|
||||
static constexpr index_t kABKLane = 0;
|
||||
static constexpr index_t kABKPerLane = 0;
|
||||
|
||||
static constexpr index_t kCMLane = 0;
|
||||
static constexpr index_t kCNLane = 0;
|
||||
static constexpr index_t kCM0PerLane = 0;
|
||||
static constexpr index_t kCM1PerLane = 0;
|
||||
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
CK_TILE_DEVICE static CVecType const&
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
// Prints once across all thread blocks and threads.
|
||||
static __device__ int printed = 0;
|
||||
if(threadIdx.x == 0 && atomicCAS(&printed, 0, 1) == 0)
|
||||
{
|
||||
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
|
||||
}
|
||||
|
||||
ignore(regsA, regsB);
|
||||
return regsC; // No-op, just return C
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ namespace ck_tile::core::arch::mma {
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x16 fragment sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
@@ -33,40 +33,12 @@ namespace ck_tile::core::arch::mma {
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u, 4, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
@@ -84,7 +56,7 @@ struct amdgcn_mma<fp16_t,
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x32 fragment sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
@@ -92,39 +64,12 @@ struct amdgcn_mma<fp16_t,
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Packed register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
|
||||
@@ -18,28 +18,27 @@ namespace ck_tile::core::arch::mma {
|
||||
* @class MfmaDefaultSelector
|
||||
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported MFMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through
|
||||
implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam BlockKTest Current Block K dimension size to test
|
||||
* instruction for the given M/N WaveTile sizes and datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM WaveTile M dimension size
|
||||
* @tparam WaveTileN WaveTile N dimension size
|
||||
* @tparam WaveTileKTest Current WaveTile K dimension size to test
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @note Here we assume that BlockKTest is always a power-of-two integer.
|
||||
* The search strategy starts from a maximum BlockKTest size down to 1u by halving
|
||||
* @note Here we assume that WaveTileKTest is always a power-of-two integer.
|
||||
* The search strategy starts from a maximum WaveTileKTest size down to 1u by halving
|
||||
* each time.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileKTest,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest))
|
||||
struct MfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
@@ -48,26 +47,25 @@ struct MfmaDefaultSelector
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
// Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
|
||||
// WaveTileK=0u and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
|
||||
CandidateOp,
|
||||
typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
@@ -75,28 +73,34 @@ struct MfmaDefaultSelector
|
||||
* @struct MfmaDefaultSelector
|
||||
* @brief Implements the base case for the default MFMA selector when no supported instruction is
|
||||
* found.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM WaveTile M dimension size
|
||||
* @tparam WaveTileN WaveTile N dimension size
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
struct MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
1u,
|
||||
CompilerTarget>
|
||||
{
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
1u,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
@@ -106,32 +110,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported MFMA
|
||||
* This implements the M/N WaveTile size search strategy to find the largest supported MFMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
|
||||
@@ -163,27 +167,20 @@ struct MmaDefaultSelector<ADataType,
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
// Check if each candidate is supported for the given WaveTile sizes
|
||||
// For this case, we require the WaveTile sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported4x4 =
|
||||
CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
MmaOpTraits<CandidateOp4x4>::IsSupported && (WaveTileM % CandidateOp4x4::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp4x4::kN == 0u) && (WaveTileK % CandidateOp4x4::kK == 0u);
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
static constexpr bool IsSupported32x32 =
|
||||
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
// Select the largest supported MFMA operation for the given WaveTile shape
|
||||
using SelectedOp = std::conditional_t<
|
||||
IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_traits.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
@@ -19,44 +18,42 @@ namespace ck_tile::core::arch::mma {
|
||||
*/
|
||||
enum struct MmaAccumPolicy
|
||||
{
|
||||
// Decomposition and accumulation in row-major block order
|
||||
// Decomposition and accumulation in row-major fragment order
|
||||
ROW_MAJOR,
|
||||
// Decomposition and accumulation in col-major block order
|
||||
// Decomposition and accumulation in col-major fragment order
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
|
||||
* fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
|
||||
* (C: FragM x FragN).
|
||||
* @tparam ADataType Data type of input fragment A
|
||||
* @tparam BDataType Data type of input fragment B
|
||||
* @tparam CDataType Data type of input/output fragment C (accumulator)
|
||||
* @tparam FragM Mma fragment M dimension
|
||||
* @tparam FragN Mma fragment K dimension
|
||||
* @tparam FragK Mma fragment M dimension
|
||||
* @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
|
||||
* order)
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
|
||||
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
|
||||
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
|
||||
* @tparam ADataType Data type of input WaveTile A
|
||||
* @tparam BDataType Data type of input WaveTile B
|
||||
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
|
||||
* @tparam WaveTileM Mma WaveTile M dimension
|
||||
* @tparam WaveTileN Mma WaveTile K dimension
|
||||
* @tparam WaveTileK Mma WaveTile M dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
|
||||
* wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output fragments
|
||||
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
|
||||
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
|
||||
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output fragment. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
@@ -67,39 +64,37 @@ template <typename ADataType,
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
{
|
||||
using FragWiseMmaOp = MmaOp;
|
||||
|
||||
using BlockWiseMmaOp = MmaOp;
|
||||
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
|
||||
// Fragment dimensions
|
||||
constexpr static uint32_t FragM = MmaOp::kM;
|
||||
constexpr static uint32_t FragN = MmaOp::kN;
|
||||
constexpr static uint32_t FragK = MmaOp::kK;
|
||||
|
||||
// Block dimensions
|
||||
constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
|
||||
constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
|
||||
constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
|
||||
// Fragment counts for decomposition
|
||||
constexpr static uint32_t FragsM = WaveTileM / FragM;
|
||||
constexpr static uint32_t FragsN = WaveTileN / FragN;
|
||||
constexpr static uint32_t FragsK = WaveTileK / FragK;
|
||||
constexpr static uint32_t FragsC = FragsM * FragsN;
|
||||
|
||||
// Block counts for decomposition
|
||||
constexpr static uint32_t BlocksM = FragM / BlockM;
|
||||
constexpr static uint32_t BlocksN = FragN / BlockN;
|
||||
constexpr static uint32_t BlocksK = FragK / BlockK;
|
||||
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
|
||||
// Vector types for packed registers in each fragment
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Vector types for packed registers in each block
|
||||
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
|
||||
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
|
||||
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
|
||||
|
||||
// Buffer types for fragments
|
||||
using ABufferType = AVecType[BlocksM][BlocksK];
|
||||
using BBufferType = BVecType[BlocksN][BlocksK];
|
||||
using CBufferType = CVecType[BlocksM][BlocksN];
|
||||
// Buffer types for WaveTiles
|
||||
using ABufferType = AVecType[FragsM][FragsK];
|
||||
using BBufferType = BVecType[FragsN][FragsK];
|
||||
using CBufferType = CVecType[FragsM][FragsN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
@@ -108,20 +103,20 @@ struct WaveWiseMma
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
|
||||
static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
|
||||
static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
|
||||
static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
|
||||
static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
|
||||
static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
|
||||
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
|
||||
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
|
||||
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
|
||||
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
|
||||
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
|
||||
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
|
||||
|
||||
private:
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
// This is intended to convert input WaveTiles to the native vector types
|
||||
// required by the FragWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT const&>(inputBuffer);
|
||||
}
|
||||
@@ -130,16 +125,16 @@ struct WaveWiseMma
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
// This is intended to convert input WaveTiles to the native vector types
|
||||
// required by the FragWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT&>(inputBuffer);
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
@@ -153,35 +148,35 @@ struct WaveWiseMma
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Col-major" accumulation over the M-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in col-major order
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
// "Col-major" accumulation over the M-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the fragments in col-major order
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// Convert native vector results back to the output WaveTile format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// First, we apply the necessary transforms to the input WaveTiles,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
@@ -189,32 +184,32 @@ struct WaveWiseMma
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Row-major" accumulation over the N-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in row-major order.
|
||||
// We also have to ensure that the incoming vector fragments are converted to native vector
|
||||
// types before passing to the BlockWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
// "Row-major" accumulation over the N-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
|
||||
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
|
||||
// types before passing to the FragWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// Convert native vector results back to the output WaveTile format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
public:
|
||||
/*! @brief Forward to Mma operation with specified accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
|
||||
@@ -9,32 +9,31 @@ namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MmaDefaultSelector
|
||||
* @brief Implements a default mma selector strategy for the current target architecture.
|
||||
* This is simply intended as a default selection strategy for mma instruction operations.
|
||||
* Given the particular datatypes and Fragment dimensions, the selector will attempt to
|
||||
* select the instruction with the largest K dimension that is supported on the current target
|
||||
* architecture.
|
||||
* @brief Implements a default mma selector strategy for the current target architecture. This is
|
||||
* simply intended as a default selection strategy for mma instruction operations. Given the
|
||||
* particular datatypes and WaveTile dimensions, the selector will attempt to select the instruction
|
||||
* with the largest K dimension that is supported on the current target architecture.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Fragment M dimension
|
||||
* @tparam FragN Fragment N dimension
|
||||
* @tparam FragK Fragment K dimension
|
||||
* @tparam WaveTileM WaveTile M dimension
|
||||
* @tparam WaveTileN WaveTile N dimension
|
||||
* @tparam WaveTileK WaveTile K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam Enable SFINAE enabler
|
||||
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
|
||||
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
|
||||
* correspond to the size of the individual MMA instructions being used to compute the overall in
|
||||
* block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
|
||||
* equal to the Block sizes.
|
||||
* @note Here we distinguish that WaveTile MNK sizes from Fragment MNK sizes used in the actual MMA
|
||||
* operation. WaveTile sizes correspond to the overall tile size being computed, while Fragment
|
||||
* sizes correspond to the size of the individual MMA instructions being used to compute the overall
|
||||
* in fragment-wise. The WaveTile sizes must be multiples of the Fragment sizes and in general
|
||||
* larger than or equal to the Fragment sizes.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily,
|
||||
typename Enable = void>
|
||||
@@ -46,9 +45,9 @@ struct MmaDefaultSelector
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>;
|
||||
|
||||
@@ -44,122 +44,64 @@ struct is_mma_op_supported<MmaOp,
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @class MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
struct MmaOpParams;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
template <typename T>
|
||||
struct MmaOpTraits;
|
||||
|
||||
/**
|
||||
* @concept MmaOpParamsI
|
||||
* @brief Expresses the required members for each MmaOp
|
||||
*/
|
||||
template <typename MmaOpParams>
|
||||
concept MmaOpParamsI = requires(MmaOpParams op) {
|
||||
// Capture template parameters
|
||||
typename MmaOpParams::ADataType;
|
||||
typename MmaOpParams::BDataType;
|
||||
typename MmaOpParams::CDataType;
|
||||
typename MmaOpParams::CtrlFlags;
|
||||
|
||||
{ MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
|
||||
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @struct MmaOpTraits
|
||||
* @brief Gives additional traits and unexposed template parameters of a given MmaOp
|
||||
* @tparam ADataType_ Data type of matrix A
|
||||
* @tparam BDataType_ Data type of matrix B
|
||||
* @tparam CDataType_ Data type of the accumulator
|
||||
* @tparam BlockM_ Size of the M dimension
|
||||
* @tparam BlockN_ Size of the N dimension
|
||||
* @tparam BlockK_ Size of the K dimension
|
||||
* @tparam FragM_ Size of the M dimension
|
||||
* @tparam FragN_ Size of the N dimension
|
||||
* @tparam FragK_ Size of the K dimension
|
||||
* @tparam CtrlFlags_ Control flags for the MMA operation
|
||||
* @tparam CompilerTarget_ The compiler target
|
||||
*/
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
uint32_t BlockM_,
|
||||
uint32_t BlockN_,
|
||||
uint32_t BlockK_,
|
||||
uint32_t FragM_,
|
||||
uint32_t FragN_,
|
||||
uint32_t FragK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_,
|
||||
MmaOpFamily OpFamily_>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
|
||||
struct MmaOpParams<amdgcn_mma<ADataType_,
|
||||
struct MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockM_,
|
||||
BlockN_,
|
||||
BlockK_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>>
|
||||
{
|
||||
// Capture incoming template parameters
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
static constexpr uint32_t BlockM = BlockM_;
|
||||
static constexpr uint32_t BlockN = BlockN_;
|
||||
static constexpr uint32_t BlockK = BlockK_;
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
static constexpr auto MmaOpFamily = OpFamily_;
|
||||
using MmaOp = amdgcn_mma<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>;
|
||||
|
||||
// Capture incoming template parameters not already in amdgcn
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaOpTraits
|
||||
* @brief Reflects the template parameters and static members of a given MmaOp.
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
|
||||
struct MmaOpTraits : public MmaOpParams<MmaOp>
|
||||
{
|
||||
// Capture internal MmaOp static members
|
||||
using OpType = typename MmaOp::OpType;
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
|
||||
|
||||
// Capture layout parameters
|
||||
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
|
||||
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
|
||||
static constexpr index_t kAMLane = MmaOp::kAMLane;
|
||||
static constexpr index_t kBNLane = MmaOp::kBNLane;
|
||||
static constexpr index_t kABKLane = MmaOp::kABKLane;
|
||||
static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
|
||||
static constexpr index_t kCMLane = MmaOp::kCMLane;
|
||||
static constexpr index_t kCNLane = MmaOp::kCNLane;
|
||||
static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
|
||||
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
|
||||
|
||||
// Additional traits to identify the type of MmaOp at compile time
|
||||
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
|
||||
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
|
||||
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
|
||||
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
|
||||
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
|
||||
constexpr static bool IsDense = OpFamily_ == MmaOpFamily::DENSE;
|
||||
constexpr static bool IsSparse = OpFamily_ == MmaOpFamily::SPARSE;
|
||||
constexpr static bool IsScale = OpFamily_ == MmaOpFamily::SCALE;
|
||||
constexpr static bool IsSupported =
|
||||
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
|
||||
is_mma_op_supported_v<MmaOp> && OpFamily_ != MmaOpFamily::UNDEFINED;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -14,23 +14,24 @@ namespace ck_tile::core::arch::mma {
|
||||
/**
|
||||
* @class SparseMfmaDefaultSelector
|
||||
* @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam WaveTileKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
|
||||
// is_power_of_two_integer(WaveTileKTest))
|
||||
struct SparseMfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
@@ -38,26 +39,24 @@ struct SparseMfmaDefaultSelector
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
@@ -67,21 +66,21 @@ struct SparseMfmaDefaultSelector
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
@@ -89,9 +88,9 @@ template <typename ADataType,
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
|
||||
@@ -126,23 +125,17 @@ struct MmaDefaultSelector<ADataType,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
// Check if each candidate is supported for the given WaveTile sizes
|
||||
// For this case, we require the WaveTile sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
static constexpr bool IsSupported32x32 =
|
||||
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
// Select the largest supported MFMA operation for the given WaveTile shape
|
||||
using SelectedOp =
|
||||
std::conditional_t<IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck_tile::core::arch::mma {
|
||||
* @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
|
||||
*
|
||||
* This specialization implements the SMFMA instruction for fp16_t A and B
|
||||
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
|
||||
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
@@ -24,53 +24,25 @@ namespace ck_tile::core::arch::mma {
|
||||
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<
|
||||
fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE,
|
||||
std::enable_if_t<is_any_value_of(
|
||||
CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
|
||||
|
||||
static constexpr index_t ABVecN = 8;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using BVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
|
||||
|
||||
|
||||
@@ -13,23 +13,24 @@ namespace ck_tile::core::arch::mma {
|
||||
/**
|
||||
* @class SparseWmmaDefaultSelector
|
||||
* @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam WaveTileKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) &&
|
||||
// is_power_of_two_integer(WaveTileKTest))
|
||||
struct SparseWmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
@@ -37,9 +38,9 @@ struct SparseWmmaDefaultSelector
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultSparseWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
@@ -54,9 +55,9 @@ struct SparseWmmaDefaultSelector
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
@@ -66,21 +67,21 @@ struct SparseWmmaDefaultSelector
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the RDNA default MMA selector strategy for sparse WMMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
@@ -88,9 +89,9 @@ template <typename ADataType,
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
|
||||
@@ -116,18 +117,14 @@ struct MmaDefaultSelector<ADataType,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
// Check if each candidate is supported for the given WaveTile sizes
|
||||
// For this case, we require the WaveTile sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
// Select the largest supported WMMA operation for the given WaveTile shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
|
||||
@@ -15,51 +15,24 @@ namespace ck_tile::core::arch::mma {
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
|
||||
|
||||
static constexpr index_t ABVecN = 16;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using BVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 8);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const int32_t idx = ::ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {
|
||||
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
|
||||
|
||||
@@ -70,38 +70,12 @@ struct DefaultWmmaCtrlFlags
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types (duplicated input / b32 accum)
|
||||
using AVecType = ext_vector_t<fp16_t, 16>;
|
||||
using BVecType = ext_vector_t<fp16_t, 16>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
|
||||
@@ -30,38 +30,12 @@ namespace ck_tile::core::arch::mma {
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
|
||||
@@ -14,24 +14,24 @@ namespace ck_tile::core::arch::mma {
|
||||
* @class WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported WMMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* instruction for the given M/N WaveTile sizes and datatypes.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam WaveTileKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest))
|
||||
struct WmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
@@ -42,27 +42,25 @@ struct WmmaDefaultSelector
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
// Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
|
||||
// WaveTileK=0u and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
|
||||
CandidateOp,
|
||||
typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
@@ -72,21 +70,27 @@ struct WmmaDefaultSelector
|
||||
* This implements the K dimension == 1, which is the base case for the recursive K dimension
|
||||
* search. If no supported instruction is found, falls back to an unsupported pass-through
|
||||
* implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
|
||||
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
struct WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
1u,
|
||||
CompilerTarget>
|
||||
{
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
@@ -95,8 +99,8 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
1u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
@@ -106,24 +110,24 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported WMMA
|
||||
* This implements the M/N WaveTile size search strategy to find the largest supported WMMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
@@ -131,9 +135,9 @@ template <typename ADataType,
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
|
||||
@@ -155,18 +159,14 @@ struct MmaDefaultSelector<ADataType,
|
||||
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
// Check if each candidate is supported for the given WaveTile sizes
|
||||
// For this case, we require the WaveTile sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
// Select the largest supported WMMA operation for the given WaveTile shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user