diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index f9f9d7ca37..63148faf99 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -14,6 +14,160 @@ namespace ck_tile::core::arch::mma { +/**--------------------------------------------------- + * Meaning of amdgcn_mma layout parameters (general) + * --------------------------------------------------- + * + * The fragment (MmaTile) sizes and layout constants in the amdgcn_mma struct describe the mapping + * between intrinsic input / output matrix elements and vector registers (lane x vector_item space). + * Note that we end up having a mapping for A, B and C separately, although those for A and B are + * usually similar if not identical. All mappings can be described as an unmerge operation on one of + * the matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and + * raw other dim into the Lane and Vector_item dimensions. When considering an unmerge operation on + * a dimension K, we can label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size + * of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and + * K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and + * unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a + * K dimension of size 12: + * + * K K2 K1 K0 + * 0 0 0 0 + * 1 0 0 1 + * 2 0 1 0 + * 3 0 1 1 + * 4 0 2 0 + * 5 0 2 1 + * 6 1 0 0 + * 7 1 0 1 + * 8 1 1 0 + * 9 1 1 1 + * 10 1 2 0 + * 11 1 2 1 + * + * Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size, + * second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left). 
+ * + * If we were to use this unmerge op to describe an A matrix layout in registers, we might have for + * example that L (lane dim) is composed of K1 and M, and V (vector_item dim) is composed of K2 and + * K0. Compactly described, this would be K{3, 2} L{K1M} V{K2K0}, and if the M dimension was 2 we + * would have the following layout (6 lanes, 4 vector items each): + * + * | V0 | V1 | V2 | V3 | + * L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 | + * L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 | + * L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 | + * L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 | + * L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 | + * L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 | + * + * Note that all A matrix elements are now placed in a unique (lane, vector_item). In case a Repeat + * dimension is used, every single matrix element is mapped to multiple (Lane, vector_item) + * locations, usually along the Lane dimension. + * + * Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any + * register mapping (expressed as a tile distribution encoding). + * + * ------------------------------------------ + * Individual amdgcn_mma layout parameters + * ------------------------------------------ + * + * -- ABKPerLane -- + * The number of K dim elements in each lane. Always the same for A and B, even when they have + * different layouts. In terms of unmerge sizes, it's equal to K0 * K2, i.e the product of the sizes + * of the outermost and innermost dimensions after a double K unmerge. + * + * -- A / B NumAccess -- + * These two variables describe the size of the outermost dimension if two unmerge operations are + * required for K (so K2). Alternatively it can be described as the number of sets the vector + * dimension, which houses a number of K indices, is split up into. 
We may be able to actually + * remove the A / B NumAccess from the amdgcn struct, but it sort of depends on how load and store + * tile work and whether we want the mid-level code to always have to know about this. There are + * only two reasons for the A / B NumAccess to ever not be 1, and they are different types of + * reasons: + * + * (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not + * allow arbitrary K perms to simplify layouts. This means the layout can only properly be described + * with a Num Access value of at least 2. + * + * (load / store manipulation). It seems like the load and store tile functions end up looking for + * the size of the smallest unmerged K dimension (K0) to determine how many elements should be + * loaded at a time. Different Num Access values will lead to different load / store behavior, even + * if logically equivalent. + * + * -- A / B Repeat -- + * Variable indicating that all matrix values are represented multiple times in the vector + * registers, typically repeating in the lane dimension. This is always equal to the repeat value + * used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value + * here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition. + * + * -- CMPerLane -- + * The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e + * the product of the sizes of the outermost and innermost dimensions after a double M unmerge. + * + * -- CNumAccess -- + * Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this + * and will not try to request a specific value. Absolutely needed for logical correctness of + * register mappings since we can not perform arbitrary M permutations without messing up the A + * layout. + */ + +/** + * @class amdgcn_mma_base + * @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. 
Also puts + * all generic parameter derivations and static asserts in one place. Houses all of the + * amdgcn struct types and variables, except for the exec() function. + */ +template +struct amdgcn_mma_base +{ + using OpType = OpType_; + static constexpr MmaOpFamily OpFamily = OpFamily_; + + // Data types + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + + // Fragment (MmaTile) sizes, check description above. + static constexpr index_t kM = FragM; // M = M2 * M1 * M0 + static constexpr index_t kN = FragN; + static constexpr index_t kK = FragK; // K = K2 * K1 * K0 + + // Layout constants, check description above. + static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0 + static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2 + static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2 + static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 + static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2 + + // Register types (derived) + static constexpr index_t WaveSize = WaveSize_; + static_assert((kM * kK * kARepeat) % WaveSize == 0); + static_assert((kN * kK * kBRepeat) % WaveSize == 0); + static_assert((kM * kN) % WaveSize == 0); + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; +}; + /** * @struct Unsupported * @brief Meta-tag to indicate unsupported amdgcn_mma instance. 
@@ -31,23 +185,24 @@ template concept MmaOpI = requires(MmaOp op) { // Requires an op context typename MmaOp::OpType; + { MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>; // Captures types for inputs / outputs to mma function + typename MmaOp::ADataType; + typename MmaOp::BDataType; + typename MmaOp::CDataType; typename MmaOp::AVecType; typename MmaOp::BVecType; typename MmaOp::CVecType; // Captures CK-specific layout properties - { MmaOp::kAMBlock } -> std::convertible_to; - { MmaOp::kBNBlock } -> std::convertible_to; - { MmaOp::kAMLane } -> std::convertible_to; - { MmaOp::kBNLane } -> std::convertible_to; - { MmaOp::kABKLane } -> std::convertible_to; { MmaOp::kABKPerLane } -> std::convertible_to; - { MmaOp::kCMLane } -> std::convertible_to; - { MmaOp::kCNLane } -> std::convertible_to; - { MmaOp::kCM0PerLane } -> std::convertible_to; - { MmaOp::kCM1PerLane } -> std::convertible_to; + { MmaOp::kAKNumAccess } -> std::convertible_to; + { MmaOp::kARepeat } -> std::convertible_to; + { MmaOp::kBKNumAccess } -> std::convertible_to; + { MmaOp::kBRepeat } -> std::convertible_to; + { MmaOp::kCMPerLane } -> std::convertible_to; + { MmaOp::kCMNumAccess } -> std::convertible_to; // Static exec function { @@ -69,52 +224,40 @@ concept MmaOpI = requires(MmaOp op) { * @tparam ADataType Datatype of input A * @tparam BDataType Datatype of input B * @tparam CDataType Datatype of accumulator - * @tparam BlockM M-dimension of mma block - * @tparam BlockN N-dimension of mma block - * @tparam BlockK K-dimension of mma block + * @tparam FragM M-dimension of mma intrinsic (MmaTile) + * @tparam FragN N-dimension of mma intrinsic (MmaTile) + * @tparam FragK K-dimension of mma intrinsic (MmaTile) * @tparam CtrlFlags Control flags for mma operation * @tparam CompilerTarget The current compiler target + * @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.) 
* @tparam Enabler SFINAE enabler */ template -struct amdgcn_mma +// clang-format off +// | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma : amdgcn_mma_base +// clang-format on { - // The base instance is unsupported because there is no __builtin to wrap. - using OpType = Unsupported; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED; - - // Interface types for A, B, C vectors types - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - default to 0 - static constexpr index_t kAMBlock = 0; - static constexpr index_t kBNBlock = 0; - - static constexpr index_t kAMLane = 0; - static constexpr index_t kBNLane = 0; - static constexpr index_t kABKLane = 0; - static constexpr index_t kABKPerLane = 0; - - static constexpr index_t kCMLane = 0; - static constexpr index_t kCNLane = 0; - static constexpr index_t kCM0PerLane = 0; - static constexpr index_t kCM1PerLane = 0; - // This is a default pass-through implementation that doesn't do anything practical. CK_TILE_DEVICE static CVecType const& exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) { + // Prints once across all thread blocks and threads. + static __device__ int printed = 0; + if(threadIdx.x == 0 && atomicCAS(&printed, 0, 1) == 0) + { + printf("[WARNING] Running amdgcn_mma dummy exec function!\n"); + } + ignore(regsA, regsB); return regsC; // No-op, just return C } diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index 1d1267a839..4955e2bf7f 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -25,7 +25,7 @@ namespace ck_tile::core::arch::mma { * @brief Specialization of amdgcn_mma for MFMA on GFX9 targets * * This specialization implements the MFMA instruction for fp16_t A and B - * matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes. 
+ * matrices, and fp32_t accumulator matrix, with 16x16x16 fragment sizes. * * @tparam CtrlFlags Control flags for the MFMA operation * @tparam CompilerTarget Current compiler target @@ -33,40 +33,12 @@ namespace ck_tile::core::arch::mma { // TODO: c++20 template // TODO: c++20 requires template -struct amdgcn_mma> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - // Mfma operation type - using OpType = MfmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; - - // Register types - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 4; - static constexpr index_t kABKPerLane = 4; - - static constexpr index_t kCMLane = 4; - static constexpr index_t kCNLane = 16; - static constexpr index_t kCM0PerLane = 1; - static constexpr index_t kCM1PerLane = 4; - CK_TILE_DEVICE static auto exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { @@ -84,7 +56,7 @@ struct amdgcn_mma // TODO: c++20 requires template -struct amdgcn_mma> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - using OpType = MfmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; - - // Packed register types - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 8; - static constexpr index_t kABKPerLane = 8; - - static constexpr 
index_t kCMLane = 4; - static constexpr index_t kCNLane = 16; - static constexpr index_t kCM0PerLane = 1; - static constexpr index_t kCM1PerLane = 4; - CK_TILE_DEVICE static auto exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 051b9d30ff..2140e3317a 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -18,28 +18,27 @@ namespace ck_tile::core::arch::mma { * @class MfmaDefaultSelector * @brief Implements a default MFMA selector strategy for gfx9 target architectures. * This implements the K dimension search strategy to find the largest supported MFMA - * instruction for the given M/N block sizes and datatypes. - * If no supported instruction is found, falls back to an unsupported pass-through - implementation. - * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Block M dimension size - * @tparam BlockN Block N dimension size - * @tparam BlockKTest Current Block K dimension size to test + * instruction for the given M/N WaveTile sizes and datatypes. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM WaveTile M dimension size + * @tparam WaveTileN WaveTile N dimension size + * @tparam WaveTileKTest Current WaveTile K dimension size to test * @tparam CompilerTarget The compiler target - * @note Here we assume that BlockKTest is always a power-of-two integer. - * The search strategy starts from a maximum BlockKTest size down to 1u by halving + * @note Here we assume that WaveTileKTest is always a power-of-two integer. 
+ * The search strategy starts from a maximum WaveTileKTest size down to 1u by halving * each time. */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> -// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest)) struct MfmaDefaultSelector { private: @@ -48,26 +47,25 @@ struct MfmaDefaultSelector amdgcn_mma; - using CandidateTraits = MmaOpTraits; public: // If the candidate is supported (e.g., a backend implementation exists), then select it. - // Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u - // and fall back to the unsupported pass-through implementation. - using SelectedOp = std::conditional_t::IsSupported, CandidateOp, typename MfmaDefaultSelector::SelectedOp>; }; @@ -75,28 +73,34 @@ struct MfmaDefaultSelector * @struct MfmaDefaultSelector * @brief Implements the base case for the default MFMA selector when no supported instruction is * found. 
- * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Block M dimension size - * @tparam BlockN Block N dimension size + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM WaveTile M dimension size + * @tparam WaveTileN WaveTile N dimension size * @tparam CompilerTarget The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> -struct MfmaDefaultSelector +struct MfmaDefaultSelector { // Default unsupported pass-through if no instruction is found using SelectedOp = amdgcn_mma // TODO: c++20 amdgcn_target_arch_id CompilerTarget> struct MmaDefaultSelector, @@ -163,27 +167,20 @@ struct MmaDefaultSelector:: SelectedOp; - // Traits for each candidate - using CandidateTraits4x4 = MmaOpTraits; - using CandidateTraits16x16 = MmaOpTraits; - using CandidateTraits32x32 = MmaOpTraits; - - // Check if each candidate is supported for the given fragment sizes - // For this case, we require the fragment sizes to be multiples of the MFMA shape + // Check if each candidate is supported for the given WaveTile sizes + // For this case, we require the WaveTile sizes to be multiples of the MFMA shape static constexpr bool IsSupported4x4 = - CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) && - (FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u); - static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && - (FragM % CandidateTraits16x16::BlockM == 0u) && - (FragN % CandidateTraits16x16::BlockN == 0u) && - (FragK % CandidateTraits16x16::BlockK == 0u); - static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported && - (FragM % CandidateTraits32x32::BlockM == 0u) && - (FragN % CandidateTraits32x32::BlockN == 0u) && - (FragK % CandidateTraits32x32::BlockK == 
0u); + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp4x4::kM == 0u) && + (WaveTileN % CandidateOp4x4::kN == 0u) && (WaveTileK % CandidateOp4x4::kK == 0u); + static constexpr bool IsSupported16x16 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) && + (WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u); + static constexpr bool IsSupported32x32 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) && + (WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u); public: - // Select the largest supported MFMA operation for the given fragment shape + // Select the largest supported MFMA operation for the given WaveTile shape using SelectedOp = std::conditional_t< IsSupported32x32, CandidateOp32x32, diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp index 9b38ff9b18..b0eb507b49 100644 --- a/include/ck_tile/core/arch/mma/mma.hpp +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -6,7 +6,6 @@ #include "amdgcn_mma.hpp" #include "mma_selector.hpp" -#include "mma_traits.hpp" #include "mma_transforms.hpp" #include "mfma/mfma.hpp" @@ -19,44 +18,42 @@ namespace ck_tile::core::arch::mma { */ enum struct MmaAccumPolicy { - // Decomposition and accumulation in row-major block order + // Decomposition and accumulation in row-major fragment order ROW_MAJOR, - // Decomposition and accumulation in col-major block order + // Decomposition and accumulation in col-major fragment order COL_MAJOR }; /** * @class Mma - * @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation - * (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input - * fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment - * (C: FragM x FragN). 
- * @tparam ADataType Data type of input fragment A - * @tparam BDataType Data type of input fragment B - * @tparam CDataType Data type of input/output fragment C (accumulator) - * @tparam FragM Mma fragment M dimension - * @tparam FragN Mma fragment K dimension - * @tparam FragK Mma fragment M dimension - * @tparam AccumPolicy The block order of the accumulation registers (row major or col major block - * order) + * @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation + * (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to + * matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and + * accumulates results into output WaveTile (C: WaveTileM x WaveTileN). + * @tparam ADataType Data type of input WaveTile A + * @tparam BDataType Data type of input WaveTile B + * @tparam CDataType Data type of input/output WaveTile C (accumulator) + * @tparam WaveTileM Mma WaveTile M dimension + * @tparam WaveTileN Mma WaveTile N dimension + * @tparam WaveTileK Mma WaveTile K dimension + * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) * @tparam CompilerTarget The compiler target - * @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or - * wmma) - * @tparam MmaTransforms The set of transforms to be applied to input/output fragments + * @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma) + * @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles * @par This is an example of an Mma decomposition driver class that can be used in a wave-tile - * context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops + * context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments * that are natively supported by the hardware (e.g., mfma or wmma). 
The class also supports - * applying transforms to the input/output fragments as needed (e.g., layout conversions, data type + * applying transforms to the input/output frags as needed (e.g., layout conversions, data type * conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the - * output fragment. This is a powerful example of how to build a flexible and reusable mma driver + * output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver * that can adapt to different hardware capabilities and requirements. */ template ::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> struct WaveWiseMma { + using FragWiseMmaOp = MmaOp; - using BlockWiseMmaOp = MmaOp; - using BlockWiseMmaOpTraits = MmaOpTraits; + // Fragment dimensions + constexpr static uint32_t FragM = MmaOp::kM; + constexpr static uint32_t FragN = MmaOp::kN; + constexpr static uint32_t FragK = MmaOp::kK; - // Block dimensions - constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM; - constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN; - constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK; + // Fragment counts for decomposition + constexpr static uint32_t FragsM = WaveTileM / FragM; + constexpr static uint32_t FragsN = WaveTileN / FragN; + constexpr static uint32_t FragsK = WaveTileK / FragK; + constexpr static uint32_t FragsC = FragsM * FragsN; - // Block counts for decomposition - constexpr static uint32_t BlocksM = FragM / BlockM; - constexpr static uint32_t BlocksN = FragN / BlockN; - constexpr static uint32_t BlocksK = FragK / BlockK; - constexpr static uint32_t BlocksC = BlocksM * BlocksN; + // Vector types for packed registers in each fragment + using AVecType = typename MmaOp::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; - // Vector types for packed 
registers in each block - using AVecType = typename BlockWiseMmaOpTraits::AVecType; - using BVecType = typename BlockWiseMmaOpTraits::BVecType; - using CVecType = typename BlockWiseMmaOpTraits::CVecType; - - // Buffer types for fragments - using ABufferType = AVecType[BlocksM][BlocksK]; - using BBufferType = BVecType[BlocksN][BlocksK]; - using CBufferType = CVecType[BlocksM][BlocksN]; + // Buffer types for WaveTiles + using ABufferType = AVecType[FragsM][FragsK]; + using BBufferType = BVecType[FragsN][FragsK]; + using CBufferType = CVecType[FragsM][FragsN]; // Transforms using ATransform = typename MmaTransforms::ATransform; @@ -108,20 +103,20 @@ struct WaveWiseMma using DTransform = typename MmaTransforms::DTransform; // Sanity checks - static_assert(FragM >= BlockM, "FragM must be larger than BlockM"); - static_assert(FragN >= BlockN, "FragN must be larger than BlockN"); - static_assert(FragK >= BlockK, "FragK must be larger than BlockK"); - static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM"); - static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN"); - static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK"); + static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); private: template CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer) { // TODO: Implement formatting logic as needed. 
- // This is intended to convert input fragments to the native vector types - // required by the BlockWiseMma operation for iteration + // This is intended to convert input WaveTiles to the native vector types + // required by the FragWiseMma operation for iteration static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); return reinterpret_cast(inputBuffer); } @@ -130,16 +125,16 @@ struct WaveWiseMma CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer) { // TODO: Implement formatting logic as needed. - // This is intended to convert input fragments to the native vector types - // required by the BlockWiseMma operation for iteration + // This is intended to convert input WaveTiles to the native vector types + // required by the FragWiseMma operation for iteration static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); return reinterpret_cast(inputBuffer); } /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input fragment A vector type - * @tparam VecTB The input fragment B vector type - * @tparam VecTC The input/output fragment C vector type + * @tparam VecTA The input WaveTile A vector type + * @tparam VecTB The input WaveTile B vector type + * @tparam VecTC The input/output WaveTile C vector type */ template CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum) @@ -153,35 +148,35 @@ struct WaveWiseMma auto b_frag = formatBuffer(BTransform::exec(b)); auto c_frag = formatBuffer(CTransform::exec(accum)); - // "Col-major" accumulation over the M-dimension blocks first. - // Pseudo code here, but we would basically iterate over the blocks in col-major order - for(uint32_t bn = 0u; bn < BlocksN; ++bn) + // "Col-major" accumulation over the M-dimension fragments first. 
+ // Pseudo code here, but we would basically iterate over the fragments in col-major order + for(uint32_t bn = 0u; bn < FragsN; ++bn) { - for(uint32_t bm = 0u; bm < BlocksM; ++bm) + for(uint32_t bm = 0u; bm < FragsM; ++bm) { - for(uint32_t bk = 0u; bk < BlocksK; ++bk) + for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_frag[bm][bn] = - BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); } } } - // Convert native vector results back to the output fragment format + // Convert native vector results back to the output WaveTile format // and then return after we apply the final output transform. return DTransform::exec(formatBuffer>(c_frag)); } /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input fragment A vector type - * @tparam VecTB The input fragment B vector type - * @tparam VecTC The input/output fragment C vector type + * @tparam VecTA The input WaveTile A vector type + * @tparam VecTB The input WaveTile B vector type + * @tparam VecTC The input/output WaveTile C vector type */ template CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum) { // We implement an example wave-tile pipeline here. - // First, we apply the necessary transforms to the input fragments, + // First, we apply the necessary transforms to the input WaveTiles, // then we convert the result into buffers of native vector formats // that we can easily index. Native vector formats are necessary inputs // to the given MmaOp exec function. @@ -189,32 +184,32 @@ struct WaveWiseMma auto b_frag = formatBuffer(BTransform::exec(b)); auto c_frag = formatBuffer(CTransform::exec(accum)); - // "Row-major" accumulation over the N-dimension blocks first. - // Pseudo code here, but we would basically iterate over the blocks in row-major order. 
- // We also have to ensure that the incoming vector fragments are converted to native vector - // types before passing to the BlockWiseMma exec function. - for(uint32_t bm = 0u; bm < BlocksM; ++bm) + // "Row-major" accumulation over the N-dimension fragments first. + // Pseudo code here, but we would basically iterate over the fragments in row-major order. + // We also have to ensure that the incoming vector WaveTiles are converted to native vector + // types before passing to the FragWiseMma exec function. + for(uint32_t bm = 0u; bm < FragsM; ++bm) { - for(uint32_t bn = 0u; bn < BlocksN; ++bn) + for(uint32_t bn = 0u; bn < FragsN; ++bn) { - for(uint32_t bk = 0u; bk < BlocksK; ++bk) + for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_frag[bm][bn] = - BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); } } } - // Convert native vector results back to the output fragment format + // Convert native vector results back to the output WaveTile format // and then return after we apply the final output transform. return DTransform::exec(formatBuffer>(c_frag)); } public: /*! @brief Forward to Mma operation with specified accumulation order. 
- * @tparam VecTA The input fragment A vector type - * @tparam VecTB The input fragment B vector type - * @tparam VecTC The input/output fragment C vector type + * @tparam VecTA The input WaveTile A vector type + * @tparam VecTB The input WaveTile B vector type + * @tparam VecTC The input/output WaveTile C vector type */ template CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index 1bb206283b..208b90d273 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -9,32 +9,31 @@ namespace ck_tile::core::arch::mma { /** * @class MmaDefaultSelector - * @brief Implements a default mma selector strategy for the current target architecture. - * This is simply intended as a default selection strategy for mma instruction operations. - * Given the particular datatypes and Fragment dimensions, the selector will attempt to - * select the instruction with the largest K dimension that is supported on the current target - * architecture. + * @brief Implements a default mma selector strategy for the current target architecture. This is + * simply intended as a default selection strategy for mma instruction operations. Given the + * particular datatypes and WaveTile dimensions, the selector will attempt to select the instruction + * with the largest K dimension that is supported on the current target architecture. 
* @tparam ADataType Data type of matrix A * @tparam BDataType Data type of matrix B * @tparam CDataType Data type of the accumulator - * @tparam FragM Fragment M dimension - * @tparam FragN Fragment N dimension - * @tparam FragK Fragment K dimension + * @tparam WaveTileM WaveTile M dimension + * @tparam WaveTileN WaveTile N dimension + * @tparam WaveTileK WaveTile K dimension * @tparam CompilerTarget The compiler target * @tparam OpFamily The MMA operation family * @tparam Enable SFINAE enabler - * @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA - * operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes - * correspond to the size of the individual MMA instructions being used to compute the overall in - * block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or - * equal to the Block sizes. + * @note Here we distinguish the WaveTile MNK sizes from the Fragment MNK sizes used in the actual + * MMA operation. WaveTile sizes correspond to the overall tile size being computed, while Fragment + * sizes correspond to the size of the individual MMA instructions being used to compute the overall + * result fragment-wise. The WaveTile sizes must be multiples of the Fragment sizes and in general + * larger than or equal to the Fragment sizes. 
*/ template @@ -46,9 +45,9 @@ struct MmaDefaultSelector using SelectedOp = amdgcn_mma, MmaOpFamily::UNDEFINED>; diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index 90cfd8aaf2..4f5f6ddbe3 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -44,122 +44,64 @@ struct is_mma_op_supported static constexpr bool is_mma_op_supported_v = is_mma_op_supported::value; -/** - * @class MmaOpParams - * @brief Reflects the template parameters of a given MmaOp - * @tparam MmaOp The matrix multiply-accumulate operation type to check - */ -// TODO: c++20 template -template -struct MmaOpParams; - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include +template +struct MmaOpTraits; /** - * @concept MmaOpParamsI - * @brief Expresses the required members for each MmaOp - */ -template -concept MmaOpParamsI = requires(MmaOpParams op) { - // Capture template parameters - typename MmaOpParams::ADataType; - typename MmaOpParams::BDataType; - typename MmaOpParams::CDataType; - typename MmaOpParams::CtrlFlags; - - { MmaOpParams::BlockM } -> std::convertible_to; - { MmaOpParams::BlockN } -> std::convertible_to; - { MmaOpParams::BlockK } -> std::convertible_to; - { MmaOpParams::GfxTargetId } -> std::convertible_to; - { MmaOpParams::Family } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -/** - * @struct MmaOpParams - * @brief Reflects the template parameters of a given MmaOp + * @struct MmaOpTraits + * @brief Gives additional traits and unexposed template parameters of a given MmaOp * @tparam ADataType_ Data type of matrix A * @tparam BDataType_ Data type of matrix B * @tparam CDataType_ Data type of the accumulator - * @tparam BlockM_ Size of the M dimension - * @tparam BlockN_ Size of the N dimension - * @tparam BlockK_ Size of the K dimension + * @tparam FragM_ Size of the M dimension + * @tparam FragN_ Size of the N dimension + * 
@tparam FragK_ Size of the K dimension * @tparam CtrlFlags_ Control flags for the MMA operation * @tparam CompilerTarget_ The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget_> -struct MmaOpParams> { - // Capture incoming template parameters - using ADataType = ADataType_; - using BDataType = BDataType_; - using CDataType = CDataType_; - static constexpr uint32_t BlockM = BlockM_; - static constexpr uint32_t BlockN = BlockN_; - static constexpr uint32_t BlockK = BlockK_; - using CtrlFlags = CtrlFlags_; - using CompilerTarget = CompilerTarget_; - static constexpr auto MmaOpFamily = OpFamily_; + using MmaOp = amdgcn_mma; + + // Capture incoming template parameters not already in amdgcn + using CtrlFlags = CtrlFlags_; + using CompilerTarget = CompilerTarget_; // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_; -}; - -/** - * @class MmaOpTraits - * @brief Reflects the template parameters and static members of a given MmaOp. - * @tparam MmaOp The matrix multiply-accumulate operation - */ -template -// TODO: c++20 template -// TODO: c++20 requires MmaOpParamsI> -struct MmaOpTraits : public MmaOpParams -{ - // Capture internal MmaOp static members - using OpType = typename MmaOp::OpType; - using AVecType = typename MmaOp::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; - - static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily; - - // Capture layout parameters - static constexpr index_t kAMBlock = MmaOp::kAMBlock; - static constexpr index_t kBNBlock = MmaOp::kBNBlock; - static constexpr index_t kAMLane = MmaOp::kAMLane; - static constexpr index_t kBNLane = MmaOp::kBNLane; - static constexpr index_t kABKLane = MmaOp::kABKLane; - static constexpr index_t kABKPerLane = MmaOp::kABKPerLane; - static constexpr index_t kCMLane = MmaOp::kCMLane; - static constexpr index_t kCNLane = MmaOp::kCNLane; - static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane; - static 
constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane; // Additional traits to identify the type of MmaOp at compile time constexpr static bool IsMfma = is_mma_op_mfma_v; constexpr static bool IsWmma = is_mma_op_wmma_v; - constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE; - constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE; - constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE; + constexpr static bool IsDense = OpFamily_ == MmaOpFamily::DENSE; + constexpr static bool IsSparse = OpFamily_ == MmaOpFamily::SPARSE; + constexpr static bool IsScale = OpFamily_ == MmaOpFamily::SCALE; constexpr static bool IsSupported = - is_mma_op_supported_v && OpFamily != MmaOpFamily::UNDEFINED; + is_mma_op_supported_v && OpFamily_ != MmaOpFamily::UNDEFINED; }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp index 8daea552b7..43ec457c3d 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp @@ -14,23 +14,24 @@ namespace ck_tile::core::arch::mma { /** * @class SparseMfmaDefaultSelector * @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported. 
- * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Size of the M dimension - * @tparam BlockN Size of the N dimension - * @tparam BlockKTest Size of the K dimension + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension + * @tparam WaveTileN Size of the N dimension + * @tparam WaveTileKTest Size of the K dimension * @tparam CompilerTarget The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> -// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && +// is_power_of_two_integer(WaveTileKTest)) struct SparseMfmaDefaultSelector { private: @@ -38,26 +39,24 @@ struct SparseMfmaDefaultSelector using CandidateOp = amdgcn_mma; - using CandidateTraits = MmaOpTraits; - public: // If the candidate is supported (e.g., a backend implementation exists), then select it. // Otherwise, fall back to the unsupported pass-through implementation. - using SelectedOp = std::conditional_t::IsSupported, CandidateOp, amdgcn_mma, MmaOpFamily::UNDEFINED>>; @@ -67,21 +66,21 @@ struct SparseMfmaDefaultSelector * @struct MmaDefaultSelector * @brief Implements the CDNA default MMA selector strategy for sparse MFMA. * If no supported instruction is found, falls back to an unsupported pass-through implementation. 
- * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam FragM Size of the M dimension of the fragment to decompose - * @tparam FragN Size of the N dimension of the fragment to decompose - * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose + * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose + * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose * @tparam CompilerTarget The compiler target - * @tparam OpFamily The MMA operation family + * @tparam OpFamily The MMA operation family */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> @@ -89,9 +88,9 @@ template ::SelectedOp; - // Traits for each candidate - using CandidateTraits16x16 = MmaOpTraits; - using CandidateTraits32x32 = MmaOpTraits; - - // Check if each candidate is supported for the given fragment sizes - // For this case, we require the fragment sizes to be multiples of the MFMA shape - static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && - (FragM % CandidateTraits16x16::BlockM == 0u) && - (FragN % CandidateTraits16x16::BlockN == 0u) && - (FragK % CandidateTraits16x16::BlockK == 0u); - static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported && - (FragM % CandidateTraits32x32::BlockM == 0u) && - (FragN % CandidateTraits32x32::BlockN == 0u) && - (FragK % CandidateTraits32x32::BlockK == 0u); + // Check if each candidate is supported for the given WaveTile sizes + // For this case, we require the WaveTile sizes to be multiples of the MFMA shape + static constexpr bool IsSupported16x16 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) && + (WaveTileN % CandidateOp16x16::kN 
== 0u) && (WaveTileK % CandidateOp16x16::kK == 0u); + static constexpr bool IsSupported32x32 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) && + (WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u); public: - // Select the largest supported MFMA operation for the given fragment shape + // Select the largest supported MFMA operation for the given WaveTile shape using SelectedOp = std::conditional_t // TODO: c++20 requires template -struct amdgcn_mma< - fp16_t, - fp16_t, - fp32_t, - 16u, - 16u, - 32u, - CtrlFlags, - CompilerTarget, - MmaOpFamily::SPARSE, - std::enable_if_t> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - using OpType = MfmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE; - - static constexpr index_t ABVecN = 8; - - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 4; - static constexpr index_t kABKPerLane = 8; - - static constexpr index_t kCMLane = 4; - static constexpr index_t kCNLane = 16; - static constexpr index_t kCM0PerLane = 1; - static constexpr index_t kCM1PerLane = 4; - - static constexpr index_t kCompressionRatio = 2; - CK_TILE_DEVICE static auto exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { - static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; - using AVecCompressed = ext_vector_t; + static constexpr index_t ABVecN = vector_traits::vector_size; + static constexpr index_t kCompressionRatio = 2; + static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; + using AVecCompressed = ext_vector_t; + static_assert(CompressedSize == 4); // TODO: Compressing A 
on-the-fly should be OK for now, but we need to validate // and evaluate changing this to a transform at a higher level. // aVec not being const can cause problems when running multiple intrinsics. - const int32_t idx = ck_tile::compress_a_impl(aVec); + const uint32_t idx = ck_tile::compress_a_impl(aVec); const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]}; diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp index 802e132083..8b4803b6bf 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp @@ -13,23 +13,24 @@ namespace ck_tile::core::arch::mma { /** * @class SparseWmmaDefaultSelector * @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported. - * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Size of the M dimension - * @tparam BlockN Size of the N dimension - * @tparam BlockKTest Size of the K dimension + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension + * @tparam WaveTileN Size of the N dimension + * @tparam WaveTileKTest Size of the K dimension * @tparam CompilerTarget The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> -// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && +// is_power_of_two_integer(WaveTileKTest)) struct SparseWmmaDefaultSelector { private: @@ -37,9 +38,9 @@ struct SparseWmmaDefaultSelector using CandidateOp = amdgcn_mma; @@ -54,9 +55,9 @@ struct SparseWmmaDefaultSelector amdgcn_mma, MmaOpFamily::UNDEFINED>>; @@ -66,21 +67,21 @@ struct SparseWmmaDefaultSelector * 
@struct MmaDefaultSelector * @brief Implements the RDNA default MMA selector strategy for sparse WMMA. * If no supported instruction is found, falls back to an unsupported pass-through implementation. - * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam FragM Size of the M dimension of the fragment to decompose - * @tparam FragN Size of the N dimension of the fragment to decompose - * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose + * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose + * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose * @tparam CompilerTarget The compiler target - * @tparam OpFamily The MMA operation family + * @tparam OpFamily The MMA operation family */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> @@ -88,9 +89,9 @@ template , @@ -116,18 +117,14 @@ struct MmaDefaultSelector::SelectedOp; - // Traits for each candidate - using CandidateTraits16x16 = MmaOpTraits; - - // Check if each candidate is supported for the given fragment sizes - // For this case, we require the fragment sizes to be multiples of the WMMA shape - static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && - (FragM % CandidateTraits16x16::BlockM == 0u) && - (FragN % CandidateTraits16x16::BlockN == 0u) && - (FragK % CandidateTraits16x16::BlockK == 0u); + // Check if each candidate is supported for the given WaveTile sizes + // For this case, we require the WaveTile sizes to be multiples of the WMMA shape + static constexpr bool IsSupported16x16 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) && + (WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % 
CandidateOp16x16::kK == 0u); public: - // Select the largest supported WMMA operation for the given fragment shape + // Select the largest supported WMMA operation for the given WaveTile shape using SelectedOp = std::conditional_t; }; diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index c0d0e4169a..7981fd91aa 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -15,51 +15,24 @@ namespace ck_tile::core::arch::mma { // TODO: c++20 template // TODO: c++20 requires template -struct amdgcn_mma> +// clang-format off +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - using OpType = WmmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE; - - static constexpr index_t ABVecN = 16; - - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 4; - static constexpr index_t kABKPerLane = 8; - - static constexpr index_t kCMLane = 4; - static constexpr index_t kCNLane = 16; - static constexpr index_t kCM0PerLane = 1; - static constexpr index_t kCM1PerLane = 4; - - static constexpr index_t kCompressionRatio = 2; - CK_TILE_DEVICE static auto exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { - static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; - using AVecCompressed = ext_vector_t; + static constexpr index_t ABVecN = vector_traits::vector_size; + static constexpr index_t kCompressionRatio = 2; + static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; + using AVecCompressed = ext_vector_t; + static_assert(CompressedSize == 8); // TODO: Compressing A on-the-fly should be OK for now, but 
we need to validate // and evaluate changing this to a transform at a higher level. // aVec not being const can cause problems when running multiple intrinsics. - const int32_t idx = ::ck_tile::compress_a_impl(aVec); + const uint32_t idx = ck_tile::compress_a_impl(aVec); const AVecCompressed a_vec_pruned = { aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]}; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index 568a55c659..c86190573e 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -70,38 +70,12 @@ struct DefaultWmmaCtrlFlags // TODO: c++20 template // TODO: c++20 requires template -struct amdgcn_mma()>> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on { - // Wmma operation type - using OpType = WmmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; - - // Register types (duplicated input / b32 accum) - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 8; - static constexpr index_t kABKPerLane = 8; - static constexpr index_t kCMLane = 2; - static constexpr index_t kCNLane = 2; - static constexpr index_t kCM0PerLane = 4; - static constexpr index_t kCM1PerLane = 1; - CK_TILE_DEVICE static auto exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index f047862a06..0a74bf8d65 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ 
b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -30,38 +30,12 @@ namespace ck_tile::core::arch::mma { // TODO: c++20 template // TODO: c++20 requires template -struct amdgcn_mma> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - // Wmma operation type - using OpType = WmmaOp; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; - - // Register types - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 8; - static constexpr index_t kABKPerLane = 8; - static constexpr index_t kCMLane = 2; - static constexpr index_t kCNLane = 2; - static constexpr index_t kCM0PerLane = 4; - static constexpr index_t kCM1PerLane = 1; - CK_TILE_DEVICE static auto exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType { diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index 367aa2677f..f8616ad19c 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -14,24 +14,24 @@ namespace ck_tile::core::arch::mma { * @class WmmaDefaultSelector * @brief Implements a default WMMA selector strategy for gfx11/12 target architectures. * This implements the K dimension search strategy to find the largest supported WMMA - * instruction for the given M/N block sizes and datatypes. 
- * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Size of the M dimension - * @tparam BlockN Size of the N dimension - * @tparam BlockKTest Size of the K dimension + * instruction for the given M/N WaveTile sizes and datatypes. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension + * @tparam WaveTileN Size of the N dimension + * @tparam WaveTileKTest Size of the K dimension * @tparam CompilerTarget The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> -// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest)) struct WmmaDefaultSelector { private: @@ -42,27 +42,25 @@ struct WmmaDefaultSelector using CandidateOp = amdgcn_mma; - using CandidateTraits = MmaOpTraits; - public: // If the candidate is supported (e.g., a backend implementation exists), then select it. - // Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u - // and fall back to the unsupported pass-through implementation. - using SelectedOp = std::conditional_t::IsSupported, CandidateOp, typename WmmaDefaultSelector::SelectedOp>; }; @@ -72,21 +70,27 @@ struct WmmaDefaultSelector * This implements the K dimension == 1, which is the base case for the recursive K dimension * search. If no supported instruction is found, falls back to an unsupported pass-through * implementation. 
- * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam BlockM Size of the M dimension - * @tparam BlockN Size of the N dimension + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension + * @tparam WaveTileN Size of the N dimension * @tparam CompilerTarget The compiler target */ template // TODO: c++20 amdgcn_target_arch_id GfxTargetId> -struct WmmaDefaultSelector +struct WmmaDefaultSelector { // By default, let's assume no special flags for WMMA using CtrlFlags = DefaultWmmaCtrlFlags; @@ -95,8 +99,8 @@ struct WmmaDefaultSelector // TODO: c++20 amdgcn_target_arch_id CompilerTarget> @@ -131,9 +135,9 @@ template , @@ -155,18 +159,14 @@ struct MmaDefaultSelector:: SelectedOp; - // Traits for each candidate - using CandidateTraits16x16 = MmaOpTraits; - - // Check if each candidate is supported for the given fragment sizes - // For this case, we require the fragment sizes to be multiples of the WMMA shape - static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && - (FragM % CandidateTraits16x16::BlockM == 0u) && - (FragN % CandidateTraits16x16::BlockN == 0u) && - (FragK % CandidateTraits16x16::BlockK == 0u); + // Check if each candidate is supported for the given WaveTile sizes + // For this case, we require the WaveTile sizes to be multiples of the WMMA shape + static constexpr bool IsSupported16x16 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) && + (WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u); public: - // Select the largest supported WMMA operation for the given fragment shape + // Select the largest supported WMMA operation for the given WaveTile shape using SelectedOp = std::conditional_t; }; diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp 
b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 2ed8b96f19..865c3e1011 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -41,40 +41,12 @@ using enable_if_target_id_dummy_t = std::enable_if_t template -struct amdgcn_mma> +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on { - // Mfma operation type - using OpType = DummyOpType; - static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; - - // Register types - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - // Layout constants - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 2; - - static constexpr index_t kAMLane = 3; - static constexpr index_t kBNLane = 4; - static constexpr index_t kABKLane = 5; - static constexpr index_t kABKPerLane = 6; - - static constexpr index_t kCMLane = 7; - static constexpr index_t kCNLane = 8; - static constexpr index_t kCM0PerLane = 9; - static constexpr index_t kCM1PerLane = 10; - CK_TILE_DEVICE static CVecType exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) { @@ -88,30 +60,30 @@ template using DummyAmdgcnMma = amdgcn_mma; /*! 
@struct MmaDefaultSelector * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both - * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - * @tparam FragM Size of the M dimension of the fragment to decompose - * @tparam FragN Size of the N dimension of the fragment to decompose - * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose + * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose + * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose * @tparam CompilerTarget The compiler target - * @tparam OpFamily The MMA operation family + * @tparam OpFamily The MMA operation family */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> @@ -119,9 +91,9 @@ template , @@ -145,25 +117,6 @@ TEST(TestAmdgcnMma, ArchSupported) (std::is_same::value)); // OpType is DummyOpType // Check OpFamily EXPECT_TRUE((is_mma_op_of_family_v)); - - // Check AVecType, BVecType, CVecType - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - - // Check layout constants - EXPECT_EQ(MmaOp::kAMBlock, 1); - EXPECT_EQ(MmaOp::kBNBlock, 2); - - EXPECT_EQ(MmaOp::kAMLane, 3); - EXPECT_EQ(MmaOp::kBNLane, 4); - EXPECT_EQ(MmaOp::kABKLane, 5); - EXPECT_EQ(MmaOp::kABKPerLane, 6); - - EXPECT_EQ(MmaOp::kCMLane, 7); - EXPECT_EQ(MmaOp::kCNLane, 8); - EXPECT_EQ(MmaOp::kCM0PerLane, 9); - EXPECT_EQ(MmaOp::kCM1PerLane, 10); } // Test case for unsupported architecture @@ -176,25 +129,6 @@ TEST(TestAmdgcnMma, ArchUnsupported) EXPECT_TRUE((std::is_same::value)); // OpFamily should be Undefined EXPECT_TRUE((is_mma_op_of_family_v)); - - // AVecType, BVecType, 
CVecType should match default - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - - // Layout constants should match default values (typically 0) - EXPECT_EQ(MmaOp::kAMBlock, 0); - EXPECT_EQ(MmaOp::kBNBlock, 0); - - EXPECT_EQ(MmaOp::kAMLane, 0); - EXPECT_EQ(MmaOp::kBNLane, 0); - EXPECT_EQ(MmaOp::kABKLane, 0); - EXPECT_EQ(MmaOp::kABKPerLane, 0); - - EXPECT_EQ(MmaOp::kCMLane, 0); - EXPECT_EQ(MmaOp::kCNLane, 0); - EXPECT_EQ(MmaOp::kCM0PerLane, 0); - EXPECT_EQ(MmaOp::kCM1PerLane, 0); } // Kernel to test amdgcn_mma::exec on device @@ -317,89 +251,26 @@ TEST(TestAmdgcnMma, ArchUnsupportedExecDeviceOutput) #include "ck_tile/core/arch/mma/mma_traits.hpp" -// Test MmaOpParams for supported DummyAmdgcnMma, including all member variables -TEST(TestAmdgcnMma, MmaOpParamsTraitsSupportedMembers) -{ - using MmaOp = DummyAmdgcnMma; - using Traits = MmaOpParams; - - // Check MmaOpParams members - EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same::value)); - EXPECT_EQ(Traits::BlockM, 16u); - EXPECT_EQ(Traits::BlockN, 16u); - EXPECT_EQ(Traits::BlockK, 16u); - EXPECT_TRUE((std::is_same::value)); -} - -// Test MmaOpParams for unsupported DummyAmdgcnMma, including all member variables -TEST(TestAmdgcnMma, MmaOpParamsUnsupportedMembers) -{ - using MmaOp = DummyAmdgcnMma>; - using Traits = MmaOpParams; - - // Check MmaOpParams members - EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same::value)); - EXPECT_EQ(Traits::BlockM, 16u); - EXPECT_EQ(Traits::BlockN, 16u); - EXPECT_EQ(Traits::BlockK, 16u); - EXPECT_TRUE((std::is_same::value)); -} - // Test MmaOpTraits for supported DummyAmdgcnMma, including all member variables TEST(TestAmdgcnMma, MmaOpTraitsSupportedMembers) { - using MmaOp = DummyAmdgcnMma; - using Traits = MmaOpTraits; + using MmaOp = DummyAmdgcnMma; // Check MmaOpTraits member variables - 
EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_EQ(Traits::kAMBlock, 1); - EXPECT_EQ(Traits::kBNBlock, 2); - EXPECT_EQ(Traits::kAMLane, 3); - EXPECT_EQ(Traits::kBNLane, 4); - EXPECT_EQ(Traits::kABKLane, 5); - EXPECT_EQ(Traits::kABKPerLane, 6); - EXPECT_EQ(Traits::kCMLane, 7); - EXPECT_EQ(Traits::kCNLane, 8); - EXPECT_EQ(Traits::kCM0PerLane, 9); - EXPECT_EQ(Traits::kCM1PerLane, 10); - EXPECT_FALSE(Traits::IsMfma); - EXPECT_FALSE(Traits::IsWmma); - EXPECT_TRUE(Traits::IsSupported); + EXPECT_FALSE(MmaOpTraits::IsMfma); + EXPECT_FALSE(MmaOpTraits::IsWmma); + EXPECT_TRUE(MmaOpTraits::IsSupported); } // Test MmaOpTraits for unsupported DummyAmdgcnMma, including all member variables TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers) { - using MmaOp = DummyAmdgcnMma>; - using Traits = MmaOpTraits; + using MmaOp = DummyAmdgcnMma>; // Check MmaOpTraits member variables - EXPECT_TRUE((std::is_same::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_TRUE((std::is_same>::value)); - EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED); - EXPECT_EQ(Traits::kAMBlock, 0); - EXPECT_EQ(Traits::kBNBlock, 0); - EXPECT_EQ(Traits::kAMLane, 0); - EXPECT_EQ(Traits::kBNLane, 0); - EXPECT_EQ(Traits::kABKLane, 0); - EXPECT_EQ(Traits::kABKPerLane, 0); - EXPECT_EQ(Traits::kCMLane, 0); - EXPECT_EQ(Traits::kCNLane, 0); - EXPECT_EQ(Traits::kCM0PerLane, 0); - EXPECT_EQ(Traits::kCM1PerLane, 0); - EXPECT_FALSE(Traits::IsMfma); - EXPECT_FALSE(Traits::IsWmma); - EXPECT_FALSE(Traits::IsSupported); + EXPECT_FALSE(MmaOpTraits::IsMfma); + EXPECT_FALSE(MmaOpTraits::IsWmma); + EXPECT_FALSE(MmaOpTraits::IsSupported); } // Test MmaDefaultSelector for supported DummyAmdgcnMma @@ -440,11 +311,11 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported) EXPECT_FALSE(MmaOpTraits::IsSupported); } -// Test MmaDefaultSelector for supported DummyAmdgcnMma on fragment 
sizes other than 16x16x16 -// This tests that the selector can still pick the correct MMA op even if the fragment sizes differ -TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment) +// Test MmaDefaultSelector for supported DummyAmdgcnMma on WaveTile sizes other than 16x16x16 +// This tests that the selector can still pick the correct MMA op even if the WaveTile sizes differ +TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedWaveTile) { - // Select indirectly with a fragment size of 256x128x64 + // Select indirectly with a WaveTile size of 256x128x64 using SelectedMma = MmaDefaultSelector::IsSupported); } -// Test MmaDefaultSelector for a different block size and supported arch -TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment) +// Test MmaDefaultSelector for a different WaveTile size and supported arch +TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedWaveTile) { // This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16 using SelectedMma = MmaDefaultSelector + uint32_t WaveTileM, + uint32_t WaveTileN, + uint32_t WaveTileK> __global__ void test_accum_over_k(void* a, void* b, void* c, void* out) { using Selector = MmaDefaultSelector; - using MmaOp = typename Selector::SelectedOp; - using MmaTraits = MmaOpTraits; - + using MmaOp = typename Selector::SelectedOp; using CVecType = typename MmaOp::CVecType; - static constexpr uint32_t kIters = FragK / MmaTraits::BlockK; + static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; // Initialize the accumulator CVecType result = *reinterpret_cast(c); - // Accumulate input AxB over FragK/BlockK iterations + // Accumulate input AxB over WaveTileK/FragK iterations for(uint32_t i = 0; i < kIters; ++i) { result = MmaOp::exec(*reinterpret_cast(a), @@ -561,16 +430,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) using BType = fp16_t; using CType = fp32_t; - // Fragment size, also the expected block size from the selector. 
- // Note: Actual blockK might be slightly different due to hardware implementation, but the + // WaveTile size, also the expected fragment size (MmaTile) from the selector. + // Note: Actual FragK might be slightly different due to hardware implementation, but the // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is // correct. - static constexpr uint32_t FragM = 16; - static constexpr uint32_t FragN = 16; - static constexpr uint32_t FragK = 32; - static constexpr uint32_t BlockM = FragM; - static constexpr uint32_t BlockN = FragN; - static constexpr uint32_t BlockK = FragK; + static constexpr uint32_t WaveTileM = 16; + static constexpr uint32_t WaveTileN = 16; + static constexpr uint32_t WaveTileK = 32; + static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; // Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1) // TODO: c++20 use is_target_family_gfx11(currentArchId) @@ -581,9 +450,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) uint32_t MultiplierC = 1; // The number of elements per thread - uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA; - uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB; - uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC; + uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA; + uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB; + uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC; uint32_t ASize = AElements * sizeof(AType); uint32_t BSize = BElements * sizeof(BType); @@ -611,16 +480,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); const auto wave_size = getDeviceWaveSize(); - test_accum_over_k + test_accum_over_k <<<1, wave_size>>>(d_a, d_b, d_c, d_out); 
HIP_CHECK_ERROR(hipDeviceSynchronize()); HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - // Output should be FragK for all elements, because the inputs are all 1's + // Output should be WaveTileK for all elements, because the inputs are all 1's for(size_t i = 0; i < CElements; ++i) { - CType expected = static_cast(FragK); + CType expected = static_cast(WaveTileK); EXPECT_NEAR(h_out[i], expected, 1e-3); } @@ -633,7 +502,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) // Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32 // The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the -// fragment sizes are larger than 16x16x32. This tests that the selector can handle larger fragment +// WaveTile sizes are larger than 16x16x32. This tests that the selector can handle larger WaveTile // sizes and still select the correct MmaOp. TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) { @@ -659,19 +528,19 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) using BType = fp16_t; using CType = fp32_t; - // Fragment size to test for decomposition. - // We expect the selector to pick a 16x16 block - static constexpr uint32_t FragM = 112; - static constexpr uint32_t FragN = 112; - static constexpr uint32_t FragK = 128; + // WaveTile size to test for decomposition. + // We expect the selector to pick a 16x16 WaveTile + static constexpr uint32_t WaveTileM = 112; + static constexpr uint32_t WaveTileN = 112; + static constexpr uint32_t WaveTileK = 128; - // The expected block size from the selector (multiple of 16). - // Note: Actual blockK might be slightly different due to hardware implementation, but the + // The expected fragment size from the selector (MmaTile, multiple of 16). 
+ // Note: Actual FragK might be slightly different due to hardware implementation, but the // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is // correct. - static constexpr uint32_t BlockM = 16; - static constexpr uint32_t BlockN = 16; - static constexpr uint32_t BlockK = 32; + static constexpr uint32_t FragM = 16; + static constexpr uint32_t FragN = 16; + static constexpr uint32_t FragK = 32; // Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1) // TODO: c++20 use is_target_family_gfx11(currentArchId) @@ -682,9 +551,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) uint32_t MultiplierC = 1; // The number of elements per thread - uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA; - uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB; - uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC; + uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA; + uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB; + uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC; uint32_t ASize = AElements * sizeof(AType); uint32_t BSize = BElements * sizeof(BType); @@ -712,16 +581,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); const auto wave_size = getDeviceWaveSize(); - test_accum_over_k + test_accum_over_k <<<1, wave_size>>>(d_a, d_b, d_c, d_out); HIP_CHECK_ERROR(hipDeviceSynchronize()); HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - // Output should be FragK for all elements, because the inputs are all 1's + // Output should be WaveTileK for all elements, because the inputs are all 1's for(size_t i = 0; i < CElements; ++i) { - CType expected = static_cast(FragK); + CType expected = static_cast(WaveTileK); EXPECT_NEAR(h_out[i], expected, 1e-3); } diff --git 
a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp index c411aaa8f4..b25d7191e2 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp @@ -55,17 +55,17 @@ namespace { * @tparam ADataType Data type of tensor A elements * @tparam BDataType Data type of tensor B elements * @tparam CDataType Data type of tensor C elements - * @tparam BlockM M-dimension of the MMA tile - * @tparam BlockN N-dimension of the MMA tile - * @tparam BlockK K-dimension of the MMA tile + * @tparam FragM M-dimension of the MMA tile + * @tparam FragN N-dimension of the MMA tile + * @tparam FragK K-dimension of the MMA tile * @tparam BlockSize HIP block size */ template struct MmaLayoutTestKernel { @@ -77,19 +77,18 @@ struct MmaLayoutTestKernel mma::MmaDefaultSelector; - using MmaOp = typename Selector::SelectedOp; - using MmaTraits = mma::MmaOpTraits; + using MmaOp = typename Selector::SelectedOp; - if constexpr(MmaTraits::IsSupported) + if constexpr(mma::MmaOpTraits::IsSupported) { - using AVecType = typename MmaTraits::AVecType; - using BVecType = typename MmaTraits::BVecType; - using CVecType = typename MmaTraits::CVecType; + using AVecType = typename MmaOp::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; constexpr uint32_t a_vec_size = vector_traits::vector_size; constexpr uint32_t b_vec_size = vector_traits::vector_size; constexpr uint32_t c_vec_size = vector_traits::vector_size; @@ -102,9 +101,9 @@ struct MmaLayoutTestKernel // get (m, k, n), where "1" should be placed for this block const uint32_t case_idx = static_cast(blockIdx.x); - const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN); - const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK; - const uint32_t n = case_idx % MmaTraits::BlockN; + const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN); + const uint32_t k = 
(case_idx / MmaOp::kN) % MmaOp::kK; + const uint32_t n = case_idx % MmaOp::kN; // place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping for(uint32_t v = 0; v < a_vec_size; ++v) @@ -174,12 +173,12 @@ bool run_mma_layout_test() { using MmaOp = typename Selector::SelectedOp; using MmaTraits = mma::MmaOpTraits; - using ADataType = typename MmaTraits::ADataType; - using BDataType = typename MmaTraits::BDataType; - using CDataType = typename MmaTraits::CDataType; - constexpr uint32_t BlockM = MmaTraits::BlockM; - constexpr uint32_t BlockN = MmaTraits::BlockN; - constexpr uint32_t BlockK = MmaTraits::BlockK; + using ADataType = typename MmaOp::ADataType; + using BDataType = typename MmaOp::BDataType; + using CDataType = typename MmaOp::CDataType; + constexpr uint32_t FragM = MmaOp::kM; + constexpr uint32_t FragN = MmaOp::kN; + constexpr uint32_t FragK = MmaOp::kK; constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID; constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID; @@ -202,7 +201,7 @@ bool run_mma_layout_test() return false; } - constexpr uint32_t total_cases = BlockM * BlockK * BlockN; + constexpr uint32_t total_cases = FragM * FragK * FragN; ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t)); std::vector h_errors(total_cases, 0u); @@ -213,9 +212,9 @@ bool run_mma_layout_test() using Kernel = MmaLayoutTestKernel(selector_wave_size)>; std::ignore = @@ -232,9 +231,9 @@ bool run_mma_layout_test() for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx) { - const uint32_t m = case_idx / (BlockK * BlockN); - const uint32_t k = (case_idx / BlockN) % BlockK; - const uint32_t n = case_idx % BlockN; + const uint32_t m = case_idx / (FragK * FragN); + const uint32_t k = (case_idx / FragN) % FragK; + const uint32_t n = case_idx % FragN; EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n; } diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp 
b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp index 850435d256..3b33fa56a6 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" -#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/vector_type.hpp" @@ -93,12 +92,9 @@ struct RegisterMapTraits; - using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; - static constexpr index_t WaveSize = - static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID); - static constexpr index_t AVecSize = vector_traits::vector_size; - static constexpr index_t BVecSize = vector_traits::vector_size; - static constexpr index_t CVecSize = vector_traits::vector_size; + static constexpr index_t AVecSize = vector_traits::vector_size; + static constexpr index_t BVecSize = vector_traits::vector_size; + static constexpr index_t CVecSize = vector_traits::vector_size; using kABPs2RHssMajor = sequence<2, 1>; using kABPs2RHssMinor = sequence<1, 0>; @@ -176,12 +172,9 @@ struct RegisterMapTraits; - using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; - static constexpr index_t WaveSize = - static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID); - static constexpr index_t AVecSize = vector_traits::vector_size; - static constexpr index_t BVecSize = vector_traits::vector_size; - static constexpr index_t CVecSize = vector_traits::vector_size; + static constexpr index_t AVecSize = vector_traits::vector_size; + static constexpr index_t BVecSize = vector_traits::vector_size; + static constexpr index_t CVecSize = vector_traits::vector_size; using kABPs2RHssMajor = sequence<2, 1>; using kABPs2RHssMinor = sequence<0, 0>; @@ -192,29 +185,41 @@ struct RegisterMapTraits; using kCYs2RHsMinor = sequence<1>; - using 
AWarpDstrEncoding = tile_distribution_encoding< - sequence<1>, - tuple, sequence>, - tuple, - tuple, - kABYs2RHsMajor, - kABYs2RHsMinor>; + // TODO: remove these and fix constants in amdgcn_mma + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<1>, - tuple, sequence>, - tuple, - tuple, - kABYs2RHsMajor, - kABYs2RHsMinor>; + using AWarpDstrEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, + tuple, + kABYs2RHsMajor, + kABYs2RHsMinor>; - using CWarpDstrEncoding = tile_distribution_encoding< - sequence<1>, - tuple, sequence>, - tuple, - tuple, - kCYs2RHsMajor, - kCYs2RHsMinor>; + using BWarpDstrEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, + tuple, + kABYs2RHsMajor, + kABYs2RHsMinor>; + + using CWarpDstrEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, + tuple, + kCYs2RHsMajor, + kCYs2RHsMinor>; }; /** @@ -245,12 +250,9 @@ struct RegisterMapTraits; - using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; - static constexpr index_t WaveSize = - static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID); - static constexpr index_t AVecSize = vector_traits::vector_size; - static constexpr index_t BVecSize = vector_traits::vector_size; - static constexpr index_t CVecSize = vector_traits::vector_size; + static constexpr index_t AVecSize = vector_traits::vector_size; + static constexpr index_t BVecSize = vector_traits::vector_size; + static constexpr index_t CVecSize = vector_traits::vector_size; using kABPs2RHssMajor = sequence<0, 1>; using kABPs2RHssMinor = sequence<0, 0>; diff --git 
a/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index 735eac09b0..03abcb5772 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -144,32 +144,29 @@ TEST(SparseMMATrait, SparseSelector) template + uint32_t WaveTileM, + uint32_t WaveTileN, + uint32_t WaveTileK> __global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) { using CompilerTarget = decltype(get_compiler_target()); using Selector = MmaDefaultSelector; + using MmaOp = typename Selector::SelectedOp; + using CVecType = typename MmaOp::CVecType; - using MmaOp = typename Selector::SelectedOp; - using MmaTraits = MmaOpTraits; - - using CVecType = typename MmaOp::CVecType; - - static constexpr uint32_t kIters = FragK / MmaTraits::BlockK; + static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; // Initialize the accumulator CVecType result = *reinterpret_cast(c); - // Accumulate input AxB over FragK/BlockK iterations + // Accumulate input AxB over WaveTileK/FragK iterations for(uint32_t i = 0; i < kIters; ++i) { result = MmaOp::exec(*reinterpret_cast(a), @@ -210,21 +207,21 @@ TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) using BType = fp16_t; using CType = fp32_t; - // Fragment size, also the expected block size from the selector. - // Note: Actual blockK might be slightly different due to hardware implementation, but the + // WaveTile size, also the expected fragment size (MmaTile) from the selector. + // Note: Actual FragK might be slightly different due to hardware implementation, but the // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is // correct. 
- static constexpr uint32_t FragM = 16; - static constexpr uint32_t FragN = 16; - static constexpr uint32_t FragK = 32; - static constexpr uint32_t BlockM = FragM; - static constexpr uint32_t BlockN = FragN; - static constexpr uint32_t BlockK = FragK; + static constexpr uint32_t WaveTileM = 16; + static constexpr uint32_t WaveTileN = 16; + static constexpr uint32_t WaveTileK = 32; + static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; // The number of elements per thread - uint32_t AElements = BlockM * BlockK / deviceWarpSize; - uint32_t BElements = BlockN * BlockK / deviceWarpSize; - uint32_t CElements = BlockM * BlockN / deviceWarpSize; + uint32_t AElements = FragM * FragK / deviceWarpSize; + uint32_t BElements = FragN * FragK / deviceWarpSize; + uint32_t CElements = FragM * FragN / deviceWarpSize; uint32_t ASize = AElements * sizeof(AType); uint32_t BSize = BElements * sizeof(BType);