[CK_Tile] Refactor amdgcn_mma policy structs (#5272)

## Motivation
The point of this PR is to simplify the intrinsic layout parameters and
make them clearer and more flexible. A number of small refactors were
also performed to reduce boilerplate and code duplication.

## Technical Details
In CK Tile and old CK, the full set of information available in the
intrinsic wrappers, for WMMA and MFMA combined, would be something like:

```
// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kAMBlock;
static constexpr index_t kBNBlock;

static constexpr index_t kRepeat;
static constexpr index_t kAMLane;
static constexpr index_t kBNLane;
static constexpr index_t kABK0PerLane;
static constexpr index_t kABKLane;
static constexpr index_t kABK1PerLane;

static constexpr index_t kCMLane;
static constexpr index_t kCNLane;
static constexpr index_t kCM0PerLane;
static constexpr index_t kCM1PerLane;

using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
using kABYs2RHsMajor  = sequence<2, 2>;
using kABYs2RHsMinor  = sequence<0, 2>;

using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor  = sequence<1, 1>;
using kCYs2RHsMinor  = sequence<0, 2>;

using kCTPs2RHssMajor = sequence<2, 1>;
using kCTPs2RHssMinor = sequence<1, 0>;
using kCTYs2RHsMajor  = sequence<2, 2>;
using kCTYs2RHsMinor  = sequence<0, 2>;
```
Note that on top of the intrinsic sizes, we have 12 layout parameters. I have reduced this in the new design to:

```
// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kABKPerLane;  // K2 * K0, always the same, even for diff A / B layouts
static constexpr index_t kAKNumAccess; // K2
static constexpr index_t kARepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kBKNumAccess; // K2
static constexpr index_t kBRepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kCMPerLane;   // M2 * M0
static constexpr index_t kCMNumAccess; // M2

// Derived properties
using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;
```

Note that there are now only 7 layout parameters and no more dimensionality orderings. Believe it or not these 7 parameters are more general than the original 12, and can handle intrinsic and mid-level features that are currently awkward in CK Tile, like dealing with AttrNumAccess, different A / B layouts, more general block-hiding (currently very limited in CK tile), and future arch features.

Furthermore, the A, B and C vec types are now derived directly from the layout parameters to ensure internal consistency.

I added a detailed explanation of the new params in terms of register mappings at the top of amdgcn_mma.hpp.

Other refactorings I did in this PR:

- Make an amdgcn_mma_base struct to drastically reduce code duplication and potential bugs. Should also make auto-generating the amd_gcn specializations much easier.
- Simplify the MmaOpTraits significantly by only including those parameters that are not directly gettable from the MmaOp itself. This removes duplicated variables and simplifies higher level code.
- Remove overloaded "Block" term for intrinsic dimensions, and replace by "Frag" instead. Some spots were already using the term "Frag" for combined intrinsics, in which case I changed that term to "Chunk" instead.
- Remove some tests that had become somewhat pointless (setting variables and then checking their values immediately).

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Kiefer van Teutem
2026-03-20 16:07:00 +01:00
committed by GitHub
parent 0a3229ea22
commit e8f9bb0c19
17 changed files with 719 additions and 948 deletions

View File

@@ -41,40 +41,12 @@ using enable_if_target_id_dummy_t = std::enable_if_t<is_dummy_target(CompilerTar
// and can focus on testing the mechanism of selecting supported vs unsupported architectures.
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
template <typename CompilerTarget>
struct amdgcn_mma<fp32_t,
fp32_t,
fp32_t,
16u,
16u,
16u,
DummyCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_id_dummy_t<CompilerTarget>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
: amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, 64u, 1, 1, 1, 1, 1, 1, 1, DummyOpType, MmaOpFamily::DENSE>
// clang-format on
{
// Mfma operation type
using OpType = DummyOpType;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp32_t, 4>;
using BVecType = ext_vector_t<fp32_t, 4>;
using CVecType = ext_vector_t<fp32_t, 4>;
// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 2;
static constexpr index_t kAMLane = 3;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 5;
static constexpr index_t kABKPerLane = 6;
static constexpr index_t kCMLane = 7;
static constexpr index_t kCNLane = 8;
static constexpr index_t kCM0PerLane = 9;
static constexpr index_t kCM1PerLane = 10;
CK_TILE_DEVICE static CVecType
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
@@ -88,30 +60,30 @@ template <typename CompilerTarget>
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
fp32_t,
fp32_t,
16u,
16u,
16u,
8u,
8u,
8u,
DummyCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>;
/*! @struct MmaDefaultSelector
 * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam FragM Size of the M dimension of the fragment to decompose
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
@@ -119,9 +91,9 @@ template <typename ADataType,
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
@@ -145,25 +117,6 @@ TEST(TestAmdgcnMma, ArchSupported)
(std::is_same<typename MmaOp::OpType, DummyOpType>::value)); // OpType is DummyOpType
// Check OpFamily
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::DENSE, MmaOp>));
// Check AVecType, BVecType, CVecType
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 4>>::value));
EXPECT_TRUE((std::is_same<typename MmaOp::BVecType, ext_vector_t<fp32_t, 4>>::value));
EXPECT_TRUE((std::is_same<typename MmaOp::CVecType, ext_vector_t<fp32_t, 4>>::value));
// Check layout constants
EXPECT_EQ(MmaOp::kAMBlock, 1);
EXPECT_EQ(MmaOp::kBNBlock, 2);
EXPECT_EQ(MmaOp::kAMLane, 3);
EXPECT_EQ(MmaOp::kBNLane, 4);
EXPECT_EQ(MmaOp::kABKLane, 5);
EXPECT_EQ(MmaOp::kABKPerLane, 6);
EXPECT_EQ(MmaOp::kCMLane, 7);
EXPECT_EQ(MmaOp::kCNLane, 8);
EXPECT_EQ(MmaOp::kCM0PerLane, 9);
EXPECT_EQ(MmaOp::kCM1PerLane, 10);
}
// Test case for unsupported architecture
@@ -176,25 +129,6 @@ TEST(TestAmdgcnMma, ArchUnsupported)
EXPECT_TRUE((std::is_same<typename MmaOp::OpType, Unsupported>::value));
// OpFamily should be Undefined
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::UNDEFINED, MmaOp>));
// AVecType, BVecType, CVecType should match default
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename MmaOp::BVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename MmaOp::CVecType, ext_vector_t<fp32_t, 1>>::value));
// Layout constants should match default values (typically 0)
EXPECT_EQ(MmaOp::kAMBlock, 0);
EXPECT_EQ(MmaOp::kBNBlock, 0);
EXPECT_EQ(MmaOp::kAMLane, 0);
EXPECT_EQ(MmaOp::kBNLane, 0);
EXPECT_EQ(MmaOp::kABKLane, 0);
EXPECT_EQ(MmaOp::kABKPerLane, 0);
EXPECT_EQ(MmaOp::kCMLane, 0);
EXPECT_EQ(MmaOp::kCNLane, 0);
EXPECT_EQ(MmaOp::kCM0PerLane, 0);
EXPECT_EQ(MmaOp::kCM1PerLane, 0);
}
// Kernel to test amdgcn_mma::exec on device
@@ -317,89 +251,26 @@ TEST(TestAmdgcnMma, ArchUnsupportedExecDeviceOutput)
#include "ck_tile/core/arch/mma/mma_traits.hpp"
// Test MmaOpParams for supported DummyAmdgcnMma, including all member variables
TEST(TestAmdgcnMma, MmaOpParamsTraitsSupportedMembers)
{
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
using Traits = MmaOpParams<MmaOp>;
// Check MmaOpParams members
EXPECT_TRUE((std::is_same<typename Traits::ADataType, fp32_t>::value));
EXPECT_TRUE((std::is_same<typename Traits::BDataType, fp32_t>::value));
EXPECT_TRUE((std::is_same<typename Traits::CDataType, fp32_t>::value));
EXPECT_EQ(Traits::BlockM, 16u);
EXPECT_EQ(Traits::BlockN, 16u);
EXPECT_EQ(Traits::BlockK, 16u);
EXPECT_TRUE((std::is_same<typename Traits::CtrlFlags, DummyCtrlFlags>::value));
}
// Test MmaOpParams for unsupported DummyAmdgcnMma, including all member variables
TEST(TestAmdgcnMma, MmaOpParamsUnsupportedMembers)
{
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
using Traits = MmaOpParams<MmaOp>;
// Check MmaOpParams members
EXPECT_TRUE((std::is_same<typename Traits::ADataType, fp32_t>::value));
EXPECT_TRUE((std::is_same<typename Traits::BDataType, fp32_t>::value));
EXPECT_TRUE((std::is_same<typename Traits::CDataType, fp32_t>::value));
EXPECT_EQ(Traits::BlockM, 16u);
EXPECT_EQ(Traits::BlockN, 16u);
EXPECT_EQ(Traits::BlockK, 16u);
EXPECT_TRUE((std::is_same<typename Traits::CtrlFlags, DummyCtrlFlags>::value));
}
// Test MmaOpTraits for supported DummyAmdgcnMma, including all member variables
TEST(TestAmdgcnMma, MmaOpTraitsSupportedMembers)
{
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
using Traits = MmaOpTraits<MmaOp>;
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
// Check MmaOpTraits member variables
EXPECT_TRUE((std::is_same<typename Traits::OpType, DummyOpType>::value));
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 4>>::value));
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 4>>::value));
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 4>>::value));
EXPECT_EQ(Traits::kAMBlock, 1);
EXPECT_EQ(Traits::kBNBlock, 2);
EXPECT_EQ(Traits::kAMLane, 3);
EXPECT_EQ(Traits::kBNLane, 4);
EXPECT_EQ(Traits::kABKLane, 5);
EXPECT_EQ(Traits::kABKPerLane, 6);
EXPECT_EQ(Traits::kCMLane, 7);
EXPECT_EQ(Traits::kCNLane, 8);
EXPECT_EQ(Traits::kCM0PerLane, 9);
EXPECT_EQ(Traits::kCM1PerLane, 10);
EXPECT_FALSE(Traits::IsMfma);
EXPECT_FALSE(Traits::IsWmma);
EXPECT_TRUE(Traits::IsSupported);
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsMfma);
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsWmma);
EXPECT_TRUE(MmaOpTraits<MmaOp>::IsSupported);
}
// Test MmaOpTraits for unsupported DummyAmdgcnMma, including all member variables
TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
{
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
using Traits = MmaOpTraits<MmaOp>;
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
// Check MmaOpTraits member variables
EXPECT_TRUE((std::is_same<typename Traits::OpType, Unsupported>::value));
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED);
EXPECT_EQ(Traits::kAMBlock, 0);
EXPECT_EQ(Traits::kBNBlock, 0);
EXPECT_EQ(Traits::kAMLane, 0);
EXPECT_EQ(Traits::kBNLane, 0);
EXPECT_EQ(Traits::kABKLane, 0);
EXPECT_EQ(Traits::kABKPerLane, 0);
EXPECT_EQ(Traits::kCMLane, 0);
EXPECT_EQ(Traits::kCNLane, 0);
EXPECT_EQ(Traits::kCM0PerLane, 0);
EXPECT_EQ(Traits::kCM1PerLane, 0);
EXPECT_FALSE(Traits::IsMfma);
EXPECT_FALSE(Traits::IsWmma);
EXPECT_FALSE(Traits::IsSupported);
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsMfma);
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsWmma);
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsSupported);
}
// Test MmaDefaultSelector for supported DummyAmdgcnMma
@@ -440,11 +311,11 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
}
// Test MmaDefaultSelector for supported DummyAmdgcnMma on fragment sizes other than 16x16x16
// This tests that the selector can still pick the correct MMA op even if the fragment sizes differ
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
// Test MmaDefaultSelector for supported DummyAmdgcnMma on WaveTile sizes other than 16x16x16
// This tests that the selector can still pick the correct MMA op even if the WaveTile sizes differ
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedWaveTile)
{
// Select indirectly with a fragment size of 256x128x64
// Select indirectly with a WaveTile size of 256x128x64
using SelectedMma = MmaDefaultSelector<fp32_t,
fp32_t,
fp32_t,
@@ -461,8 +332,8 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
}
// Test MmaDefaultSelector for a different block size and supported arch
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
// Test MmaDefaultSelector for a different WaveTile size and supported arch
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedWaveTile)
{
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
using SelectedMma = MmaDefaultSelector<fp32_t,
@@ -496,36 +367,34 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
// Test on real hardware for MmaOp selection.
// This is not a GEMM kernel, but a simple test to ensure that the selected MmaOp works correctly on
// real hardware. Assumption: inputs are all 1's The multiply-accumulate functionality can be tested
// here by looping over the k dimension and accumulating the results. They should be equal to FragK
// regardless of hardware.
// here by looping over the k dimension and accumulating the results. They should be equal to
// WaveTileK regardless of hardware.
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK>
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
{
using Selector = MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
decltype(get_compiler_target()),
MmaOpFamily::DENSE>;
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = MmaOpTraits<MmaOp>;
using MmaOp = typename Selector::SelectedOp;
using CVecType = typename MmaOp::CVecType;
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
// Accumulate input AxB over FragK/BlockK iterations
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
@@ -561,16 +430,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
using BType = fp16_t;
using CType = fp32_t;
// Fragment size, also the expected block size from the selector.
// Note: Actual blockK might be slightly different due to hardware implementation, but the
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = 16;
static constexpr uint32_t FragN = 16;
static constexpr uint32_t FragK = 32;
static constexpr uint32_t BlockM = FragM;
static constexpr uint32_t BlockN = FragN;
static constexpr uint32_t BlockK = FragK;
static constexpr uint32_t WaveTileM = 16;
static constexpr uint32_t WaveTileN = 16;
static constexpr uint32_t WaveTileK = 32;
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
// TODO: c++20 use is_target_family_gfx11(currentArchId)
@@ -581,9 +450,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
uint32_t MultiplierC = 1;
// The number of elements per thread
uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA;
uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB;
uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC;
uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA;
uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB;
uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
@@ -611,16 +480,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Output should be FragK for all elements, because the inputs are all 1's
// Output should be WaveTileK for all elements, because the inputs are all 1's
for(size_t i = 0; i < CElements; ++i)
{
CType expected = static_cast<CType>(FragK);
CType expected = static_cast<CType>(WaveTileK);
EXPECT_NEAR(h_out[i], expected, 1e-3);
}
@@ -633,7 +502,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32
// The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the
// fragment sizes are larger than 16x16x32. This tests that the selector can handle larger fragment
// WaveTile sizes are larger than 16x16x32. This tests that the selector can handle larger WaveTile
// sizes and still select the correct MmaOp.
TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
{
@@ -659,19 +528,19 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
using BType = fp16_t;
using CType = fp32_t;
// Fragment size to test for decomposition.
// We expect the selector to pick a 16x16 block
static constexpr uint32_t FragM = 112;
static constexpr uint32_t FragN = 112;
static constexpr uint32_t FragK = 128;
// WaveTile size to test for decomposition.
// We expect the selector to pick a 16x16 WaveTile
static constexpr uint32_t WaveTileM = 112;
static constexpr uint32_t WaveTileN = 112;
static constexpr uint32_t WaveTileK = 128;
// The expected block size from the selector (multiple of 16).
// Note: Actual blockK might be slightly different due to hardware implementation, but the
// The expected fragment size from the selector (MmaTile, multiple of 16).
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t BlockM = 16;
static constexpr uint32_t BlockN = 16;
static constexpr uint32_t BlockK = 32;
static constexpr uint32_t FragM = 16;
static constexpr uint32_t FragN = 16;
static constexpr uint32_t FragK = 32;
// Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
// TODO: c++20 use is_target_family_gfx11(currentArchId)
@@ -682,9 +551,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
uint32_t MultiplierC = 1;
// The number of elements per thread
uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA;
uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB;
uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC;
uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA;
uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB;
uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
@@ -712,16 +581,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Output should be FragK for all elements, because the inputs are all 1's
// Output should be WaveTileK for all elements, because the inputs are all 1's
for(size_t i = 0; i < CElements; ++i)
{
CType expected = static_cast<CType>(FragK);
CType expected = static_cast<CType>(WaveTileK);
EXPECT_NEAR(h_out[i], expected, 1e-3);
}

View File

@@ -55,17 +55,17 @@ namespace {
* @tparam ADataType Data type of tensor A elements
* @tparam BDataType Data type of tensor B elements
* @tparam CDataType Data type of tensor C elements
* @tparam BlockM M-dimension of the MMA tile
* @tparam BlockN N-dimension of the MMA tile
* @tparam BlockK K-dimension of the MMA tile
* @tparam FragM M-dimension of the MMA tile
* @tparam FragN N-dimension of the MMA tile
* @tparam FragK K-dimension of the MMA tile
* @tparam BlockSize HIP block size
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t BlockSize>
struct MmaLayoutTestKernel
{
@@ -77,19 +77,18 @@ struct MmaLayoutTestKernel
mma::MmaDefaultSelector<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockK,
FragM,
FragN,
FragK,
decltype(ck_tile::core::arch::get_compiler_target()),
mma::MmaOpFamily::DENSE>;
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = mma::MmaOpTraits<MmaOp>;
using MmaOp = typename Selector::SelectedOp;
if constexpr(MmaTraits::IsSupported)
if constexpr(mma::MmaOpTraits<MmaOp>::IsSupported)
{
using AVecType = typename MmaTraits::AVecType;
using BVecType = typename MmaTraits::BVecType;
using CVecType = typename MmaTraits::CVecType;
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
@@ -102,9 +101,9 @@ struct MmaLayoutTestKernel
// get (m, k, n), where "1" should be placed for this block
const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN);
const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK;
const uint32_t n = case_idx % MmaTraits::BlockN;
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
const uint32_t n = case_idx % MmaOp::kN;
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
for(uint32_t v = 0; v < a_vec_size; ++v)
@@ -174,12 +173,12 @@ bool run_mma_layout_test()
{
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = mma::MmaOpTraits<MmaOp>;
using ADataType = typename MmaTraits::ADataType;
using BDataType = typename MmaTraits::BDataType;
using CDataType = typename MmaTraits::CDataType;
constexpr uint32_t BlockM = MmaTraits::BlockM;
constexpr uint32_t BlockN = MmaTraits::BlockN;
constexpr uint32_t BlockK = MmaTraits::BlockK;
using ADataType = typename MmaOp::ADataType;
using BDataType = typename MmaOp::BDataType;
using CDataType = typename MmaOp::CDataType;
constexpr uint32_t FragM = MmaOp::kM;
constexpr uint32_t FragN = MmaOp::kN;
constexpr uint32_t FragK = MmaOp::kK;
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
@@ -202,7 +201,7 @@ bool run_mma_layout_test()
return false;
}
constexpr uint32_t total_cases = BlockM * BlockK * BlockN;
constexpr uint32_t total_cases = FragM * FragK * FragN;
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
std::vector<uint32_t> h_errors(total_cases, 0u);
@@ -213,9 +212,9 @@ bool run_mma_layout_test()
using Kernel = MmaLayoutTestKernel<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockK,
FragM,
FragN,
FragK,
static_cast<int>(selector_wave_size)>;
std::ignore =
@@ -232,9 +231,9 @@ bool run_mma_layout_test()
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
{
const uint32_t m = case_idx / (BlockK * BlockN);
const uint32_t k = (case_idx / BlockN) % BlockK;
const uint32_t n = case_idx % BlockN;
const uint32_t m = case_idx / (FragK * FragN);
const uint32_t k = (case_idx / FragN) % FragK;
const uint32_t n = case_idx % FragN;
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
}

View File

@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
@@ -93,12 +92,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
@@ -176,12 +172,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<0, 0>;
@@ -192,29 +185,41 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
using kCYs2RHsMajor = sequence<1>;
using kCYs2RHsMinor = sequence<1>;
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<1>,
tuple<sequence<MmaOp::kAMLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
tuple<kABPs2RHssMajor>,
tuple<kABPs2RHssMinor>,
kABYs2RHsMajor,
kABYs2RHsMinor>;
// TODO: remove these and fix constants in amdgcn_mma
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<1>,
tuple<sequence<MmaOp::kBNLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
tuple<kABPs2RHssMajor>,
tuple<kABPs2RHssMinor>,
kABYs2RHsMajor,
kABYs2RHsMinor>;
using AWarpDstrEncoding =
tile_distribution_encoding<sequence<1>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
tuple<kABPs2RHssMajor>,
tuple<kABPs2RHssMinor>,
kABYs2RHsMajor,
kABYs2RHsMinor>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<1>,
tuple<sequence<MmaOp::kCMLane, MmaOp::kCM1PerLane>, sequence<MmaOp::kCNLane>>,
tuple<kCPs2RHssMajor>,
tuple<kCPs2RHssMinor>,
kCYs2RHsMajor,
kCYs2RHsMinor>;
using BWarpDstrEncoding =
tile_distribution_encoding<sequence<1>,
tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
tuple<kABPs2RHssMajor>,
tuple<kABPs2RHssMinor>,
kABYs2RHsMajor,
kABYs2RHsMinor>;
using CWarpDstrEncoding =
tile_distribution_encoding<sequence<1>,
tuple<sequence<kCMLane, kCM1PerLane>, sequence<kCNLane>>,
tuple<kCPs2RHssMajor>,
tuple<kCPs2RHssMinor>,
kCYs2RHsMajor,
kCYs2RHsMinor>;
};
/**
@@ -245,12 +250,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
using kABPs2RHssMajor = sequence<0, 1>;
using kABPs2RHssMinor = sequence<0, 0>;

View File

@@ -144,32 +144,29 @@ TEST(SparseMMATrait, SparseSelector)
template <typename AType,
typename BType,
typename CType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK>
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
using CompilerTarget = decltype(get_compiler_target());
using Selector = MmaDefaultSelector<AType,
BType,
CType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SPARSE>;
using MmaOp = typename Selector::SelectedOp;
using CVecType = typename MmaOp::CVecType;
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = MmaOpTraits<MmaOp>;
using CVecType = typename MmaOp::CVecType;
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
// Accumulate input AxB over FragK/BlockK iterations
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
@@ -210,21 +207,21 @@ TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
using BType = fp16_t;
using CType = fp32_t;
// Fragment size, also the expected block size from the selector.
// Note: Actual blockK might be slightly different due to hardware implementation, but the
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = 16;
static constexpr uint32_t FragN = 16;
static constexpr uint32_t FragK = 32;
static constexpr uint32_t BlockM = FragM;
static constexpr uint32_t BlockN = FragN;
static constexpr uint32_t BlockK = FragK;
static constexpr uint32_t WaveTileM = 16;
static constexpr uint32_t WaveTileN = 16;
static constexpr uint32_t WaveTileK = 32;
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = BlockM * BlockK / deviceWarpSize;
uint32_t BElements = BlockN * BlockK / deviceWarpSize;
uint32_t CElements = BlockM * BlockN / deviceWarpSize;
uint32_t AElements = FragM * FragK / deviceWarpSize;
uint32_t BElements = FragN * FragK / deviceWarpSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);