mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_Tile] Refactor amdgcn_mma policy structs (#5272)
## Motivation The point of this MR is to update the intrinsic layout parameters to simplify them and make them more clear and flexible. Also, a number of simple refactors were performed to reduce boilerplate and code duplication. ## Technical Details In CK Tile and old CK, the full set of information available in the intrinsic wrappers, for WMMA and MFMA combined, would be something like: ``` // Basic info using ADataType = void; using BDataType = void; using CDataType = void; using AVecType = ext_vector_t<ADataType, 0>; using BVecType = ext_vector_t<BDataType, 0>; using CVecType = ext_vector_t<CDataType, 0>; // Fragment sizes static constexpr index_t kM; static constexpr index_t kN; static constexpr index_t kK; // Layout parameters static constexpr index_t kAMBlock; static constexpr index_t kBNBlock; static constexpr index_t kRepeat; static constexpr index_t kAMLane; static constexpr index_t kBNLane; static constexpr index_t kABK0PerLane; static constexpr index_t kABKLane; static constexpr index_t kABK1PerLane; static constexpr index_t kCMLane; static constexpr index_t kCNLane; static constexpr index_t kCM0PerLane; static constexpr index_t kCM1PerLane; using kABPs2RHssMajor = sequence<2, 1>; using kABPs2RHssMinor = sequence<1, 0>; using kABYs2RHsMajor = sequence<2, 2>; using kABYs2RHsMinor = sequence<0, 2>; using kCPs2RHssMajor = sequence<1, 2>; using kCPs2RHssMinor = sequence<1, 0>; using kCYs2RHsMajor = sequence<1, 1>; using kCYs2RHsMinor = sequence<0, 2>; using kCTPs2RHssMajor = sequence<2, 1>; using kCTPs2RHssMinor = sequence<1, 0>; using kCTYs2RHsMajor = sequence<2, 2>; using kCTYs2RHsMinor = sequence<0, 2>; ``` Note that on top of the intrinsic sizes, we have 12 layout parameters. 
I have reduced this in the new design to: ``` // Basic info using ADataType = void; using BDataType = void; using CDataType = void; // Fragment sizes static constexpr index_t kM; static constexpr index_t kN; static constexpr index_t kK; // Layout parameters static constexpr index_t kABKPerLane; // K2 * K0, Always the same, even for diff A / B layouts static constexpr index_t kAKNumAccess; // K2 static constexpr index_t kARepeat; // Used for RDNA3 repeated inputs and CDNA block hiding. static constexpr index_t kBKNumAccess; // K2 static constexpr index_t kBRepeat; // Used for RDNA3 repeated inputs and CDNA block hiding. static constexpr index_t kCMPerLane; // M2 * M0 static constexpr index_t kCMNumAccess; // M2 // Derived properties using AVecType = ext_vector_t<ADataType, 0>; using BVecType = ext_vector_t<BDataType, 0>; using CVecType = ext_vector_t<CDataType, 0>; ``` Note that there are now only 7 layout parameters and no more dimensionality orderings. Believe it or not, these 7 parameters are more general than the original 12, and can handle intrinsic and mid-level features that are currently awkward in CK Tile, like dealing with AttrNumAccess, different A / B layouts, more general block-hiding (currently very limited in CK tile), and future arch features. Furthermore, the A, B and C vec types are now derived directly from the layout parameters to ensure internal consistency. I added a detailed explanation of the new params in terms of register mappings at the top of amdgcn_mma.hpp. Other refactorings I did in this MR: - Make an amdgcn_mma_base struct to drastically reduce code duplication and potential bugs. Should also make auto-generating the amd_gcn specializations much easier. - Simplify the MmaOpTraits significantly by only including those parameters that are not directly gettable from the MmaOp itself. This removes duplicated variables and simplifies higher level code. - Remove overloaded "Block" term for intrinsic dimensions, and replace by "Frag" instead. 
Some spots were already using the term "Frag" for combined intrinsics, in which case I changed that term to "Chunk" instead. - Remove some tests that had become somewhat pointless (setting variables and then checking their values immediately). - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
GitHub
parent
0a3229ea22
commit
e8f9bb0c19
@@ -41,40 +41,12 @@ using enable_if_target_id_dummy_t = std::enable_if_t<is_dummy_target(CompilerTar
|
||||
// and can focus on testing the mechanism of selecting supported vs unsupported architectures.
|
||||
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
struct amdgcn_mma<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, 64u, 1, 1, 1, 1, 1, 1, 1, DummyOpType, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = DummyOpType;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp32_t, 4>;
|
||||
using BVecType = ext_vector_t<fp32_t, 4>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 2;
|
||||
|
||||
static constexpr index_t kAMLane = 3;
|
||||
static constexpr index_t kBNLane = 4;
|
||||
static constexpr index_t kABKLane = 5;
|
||||
static constexpr index_t kABKPerLane = 6;
|
||||
|
||||
static constexpr index_t kCMLane = 7;
|
||||
static constexpr index_t kCNLane = 8;
|
||||
static constexpr index_t kCM0PerLane = 9;
|
||||
static constexpr index_t kCM1PerLane = 10;
|
||||
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
@@ -88,30 +60,30 @@ template <typename CompilerTarget>
|
||||
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
8u,
|
||||
8u,
|
||||
8u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
/*! @struct MmaDefaultSelector
|
||||
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
@@ -119,9 +91,9 @@ template <typename ADataType,
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
|
||||
@@ -145,25 +117,6 @@ TEST(TestAmdgcnMma, ArchSupported)
|
||||
(std::is_same<typename MmaOp::OpType, DummyOpType>::value)); // OpType is DummyOpType
|
||||
// Check OpFamily
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::DENSE, MmaOp>));
|
||||
|
||||
// Check AVecType, BVecType, CVecType
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::BVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::CVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
|
||||
// Check layout constants
|
||||
EXPECT_EQ(MmaOp::kAMBlock, 1);
|
||||
EXPECT_EQ(MmaOp::kBNBlock, 2);
|
||||
|
||||
EXPECT_EQ(MmaOp::kAMLane, 3);
|
||||
EXPECT_EQ(MmaOp::kBNLane, 4);
|
||||
EXPECT_EQ(MmaOp::kABKLane, 5);
|
||||
EXPECT_EQ(MmaOp::kABKPerLane, 6);
|
||||
|
||||
EXPECT_EQ(MmaOp::kCMLane, 7);
|
||||
EXPECT_EQ(MmaOp::kCNLane, 8);
|
||||
EXPECT_EQ(MmaOp::kCM0PerLane, 9);
|
||||
EXPECT_EQ(MmaOp::kCM1PerLane, 10);
|
||||
}
|
||||
|
||||
// Test case for unsupported architecture
|
||||
@@ -176,25 +129,6 @@ TEST(TestAmdgcnMma, ArchUnsupported)
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::OpType, Unsupported>::value));
|
||||
// OpFamily should be Undefined
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::UNDEFINED, MmaOp>));
|
||||
|
||||
// AVecType, BVecType, CVecType should match default
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::BVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::CVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
|
||||
// Layout constants should match default values (typically 0)
|
||||
EXPECT_EQ(MmaOp::kAMBlock, 0);
|
||||
EXPECT_EQ(MmaOp::kBNBlock, 0);
|
||||
|
||||
EXPECT_EQ(MmaOp::kAMLane, 0);
|
||||
EXPECT_EQ(MmaOp::kBNLane, 0);
|
||||
EXPECT_EQ(MmaOp::kABKLane, 0);
|
||||
EXPECT_EQ(MmaOp::kABKPerLane, 0);
|
||||
|
||||
EXPECT_EQ(MmaOp::kCMLane, 0);
|
||||
EXPECT_EQ(MmaOp::kCNLane, 0);
|
||||
EXPECT_EQ(MmaOp::kCM0PerLane, 0);
|
||||
EXPECT_EQ(MmaOp::kCM1PerLane, 0);
|
||||
}
|
||||
|
||||
// Kernel to test amdgcn_mma::exec on device
|
||||
@@ -317,89 +251,26 @@ TEST(TestAmdgcnMma, ArchUnsupportedExecDeviceOutput)
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
// Test MmaOpParams for supported DummyAmdgcnMma, including all member variables
|
||||
TEST(TestAmdgcnMma, MmaOpParamsTraitsSupportedMembers)
|
||||
{
|
||||
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
|
||||
using Traits = MmaOpParams<MmaOp>;
|
||||
|
||||
// Check MmaOpParams members
|
||||
EXPECT_TRUE((std::is_same<typename Traits::ADataType, fp32_t>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BDataType, fp32_t>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CDataType, fp32_t>::value));
|
||||
EXPECT_EQ(Traits::BlockM, 16u);
|
||||
EXPECT_EQ(Traits::BlockN, 16u);
|
||||
EXPECT_EQ(Traits::BlockK, 16u);
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CtrlFlags, DummyCtrlFlags>::value));
|
||||
}
|
||||
|
||||
// Test MmaOpParams for unsupported DummyAmdgcnMma, including all member variables
|
||||
TEST(TestAmdgcnMma, MmaOpParamsUnsupportedMembers)
|
||||
{
|
||||
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
|
||||
using Traits = MmaOpParams<MmaOp>;
|
||||
|
||||
// Check MmaOpParams members
|
||||
EXPECT_TRUE((std::is_same<typename Traits::ADataType, fp32_t>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BDataType, fp32_t>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CDataType, fp32_t>::value));
|
||||
EXPECT_EQ(Traits::BlockM, 16u);
|
||||
EXPECT_EQ(Traits::BlockN, 16u);
|
||||
EXPECT_EQ(Traits::BlockK, 16u);
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CtrlFlags, DummyCtrlFlags>::value));
|
||||
}
|
||||
|
||||
// Test MmaOpTraits for supported DummyAmdgcnMma, including all member variables
|
||||
TEST(TestAmdgcnMma, MmaOpTraitsSupportedMembers)
|
||||
{
|
||||
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
|
||||
using Traits = MmaOpTraits<MmaOp>;
|
||||
using MmaOp = DummyAmdgcnMma<DummyCompilerTarget>;
|
||||
|
||||
// Check MmaOpTraits member variables
|
||||
EXPECT_TRUE((std::is_same<typename Traits::OpType, DummyOpType>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
EXPECT_EQ(Traits::kAMBlock, 1);
|
||||
EXPECT_EQ(Traits::kBNBlock, 2);
|
||||
EXPECT_EQ(Traits::kAMLane, 3);
|
||||
EXPECT_EQ(Traits::kBNLane, 4);
|
||||
EXPECT_EQ(Traits::kABKLane, 5);
|
||||
EXPECT_EQ(Traits::kABKPerLane, 6);
|
||||
EXPECT_EQ(Traits::kCMLane, 7);
|
||||
EXPECT_EQ(Traits::kCNLane, 8);
|
||||
EXPECT_EQ(Traits::kCM0PerLane, 9);
|
||||
EXPECT_EQ(Traits::kCM1PerLane, 10);
|
||||
EXPECT_FALSE(Traits::IsMfma);
|
||||
EXPECT_FALSE(Traits::IsWmma);
|
||||
EXPECT_TRUE(Traits::IsSupported);
|
||||
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsMfma);
|
||||
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsWmma);
|
||||
EXPECT_TRUE(MmaOpTraits<MmaOp>::IsSupported);
|
||||
}
|
||||
|
||||
// Test MmaOpTraits for unsupported DummyAmdgcnMma, including all member variables
|
||||
TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
|
||||
{
|
||||
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
|
||||
using Traits = MmaOpTraits<MmaOp>;
|
||||
using MmaOp = DummyAmdgcnMma<amdgcn_target<>>;
|
||||
|
||||
// Check MmaOpTraits member variables
|
||||
EXPECT_TRUE((std::is_same<typename Traits::OpType, Unsupported>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED);
|
||||
EXPECT_EQ(Traits::kAMBlock, 0);
|
||||
EXPECT_EQ(Traits::kBNBlock, 0);
|
||||
EXPECT_EQ(Traits::kAMLane, 0);
|
||||
EXPECT_EQ(Traits::kBNLane, 0);
|
||||
EXPECT_EQ(Traits::kABKLane, 0);
|
||||
EXPECT_EQ(Traits::kABKPerLane, 0);
|
||||
EXPECT_EQ(Traits::kCMLane, 0);
|
||||
EXPECT_EQ(Traits::kCNLane, 0);
|
||||
EXPECT_EQ(Traits::kCM0PerLane, 0);
|
||||
EXPECT_EQ(Traits::kCM1PerLane, 0);
|
||||
EXPECT_FALSE(Traits::IsMfma);
|
||||
EXPECT_FALSE(Traits::IsWmma);
|
||||
EXPECT_FALSE(Traits::IsSupported);
|
||||
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsMfma);
|
||||
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsWmma);
|
||||
EXPECT_FALSE(MmaOpTraits<MmaOp>::IsSupported);
|
||||
}
|
||||
|
||||
// Test MmaDefaultSelector for supported DummyAmdgcnMma
|
||||
@@ -440,11 +311,11 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
|
||||
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
}
|
||||
|
||||
// Test MmaDefaultSelector for supported DummyAmdgcnMma on fragment sizes other than 16x16x16
|
||||
// This tests that the selector can still pick the correct MMA op even if the fragment sizes differ
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
// Test MmaDefaultSelector for supported DummyAmdgcnMma on WaveTile sizes other than 16x16x16
|
||||
// This tests that the selector can still pick the correct MMA op even if the WaveTile sizes differ
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedWaveTile)
|
||||
{
|
||||
// Select indirectly with a fragment size of 256x128x64
|
||||
// Select indirectly with a WaveTile size of 256x128x64
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
@@ -461,8 +332,8 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
}
|
||||
|
||||
// Test MmaDefaultSelector for a different block size and supported arch
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
|
||||
// Test MmaDefaultSelector for a different WaveTile size and supported arch
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedWaveTile)
|
||||
{
|
||||
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
@@ -496,36 +367,34 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
|
||||
// Test on real hardware for MmaOp selection.
|
||||
// This is not a GEMM kernel, but a simple test to ensure that the selected MmaOp works correctly on
|
||||
// real hardware. Assumption: inputs are all 1's. The multiply-accumulate functionality can be tested
|
||||
// here by looping over the k dimension and accumulating the results. They should be equal to FragK
|
||||
// regardless of hardware.
|
||||
// here by looping over the k dimension and accumulating the results. They should be equal to
|
||||
// WaveTileK regardless of hardware.
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK>
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK>
|
||||
__global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
|
||||
{
|
||||
using Selector = MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
decltype(get_compiler_target()),
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = MmaOpTraits<MmaOp>;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
|
||||
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
|
||||
|
||||
// Initialize the accumulator
|
||||
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
|
||||
|
||||
// Accumulate input AxB over FragK/BlockK iterations
|
||||
// Accumulate input AxB over WaveTileK/FragK iterations
|
||||
for(uint32_t i = 0; i < kIters; ++i)
|
||||
{
|
||||
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
|
||||
@@ -561,16 +430,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
using BType = fp16_t;
|
||||
using CType = fp32_t;
|
||||
|
||||
// Fragment size, also the expected block size from the selector.
|
||||
// Note: Actual blockK might be slightly different due to hardware implementation, but the
|
||||
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
|
||||
// Note: Actual FragK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t FragM = 16;
|
||||
static constexpr uint32_t FragN = 16;
|
||||
static constexpr uint32_t FragK = 32;
|
||||
static constexpr uint32_t BlockM = FragM;
|
||||
static constexpr uint32_t BlockN = FragN;
|
||||
static constexpr uint32_t BlockK = FragK;
|
||||
static constexpr uint32_t WaveTileM = 16;
|
||||
static constexpr uint32_t WaveTileN = 16;
|
||||
static constexpr uint32_t WaveTileK = 32;
|
||||
static constexpr uint32_t FragM = WaveTileM;
|
||||
static constexpr uint32_t FragN = WaveTileN;
|
||||
static constexpr uint32_t FragK = WaveTileK;
|
||||
|
||||
// Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
|
||||
// TODO: c++20 use is_target_family_gfx11(currentArchId)
|
||||
@@ -581,9 +450,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
uint32_t MultiplierC = 1;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA;
|
||||
uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB;
|
||||
uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC;
|
||||
uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA;
|
||||
uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB;
|
||||
uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
@@ -611,16 +480,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
// Output should be FragK for all elements, because the inputs are all 1's
|
||||
// Output should be WaveTileK for all elements, because the inputs are all 1's
|
||||
for(size_t i = 0; i < CElements; ++i)
|
||||
{
|
||||
CType expected = static_cast<CType>(FragK);
|
||||
CType expected = static_cast<CType>(WaveTileK);
|
||||
|
||||
EXPECT_NEAR(h_out[i], expected, 1e-3);
|
||||
}
|
||||
@@ -633,7 +502,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
|
||||
// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32
|
||||
// The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the
|
||||
// fragment sizes are larger than 16x16x32. This tests that the selector can handle larger fragment
|
||||
// WaveTile sizes are larger than 16x16x32. This tests that the selector can handle larger WaveTile
|
||||
// sizes and still select the correct MmaOp.
|
||||
TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
{
|
||||
@@ -659,19 +528,19 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
using BType = fp16_t;
|
||||
using CType = fp32_t;
|
||||
|
||||
// Fragment size to test for decomposition.
|
||||
// We expect the selector to pick a 16x16 block
|
||||
static constexpr uint32_t FragM = 112;
|
||||
static constexpr uint32_t FragN = 112;
|
||||
static constexpr uint32_t FragK = 128;
|
||||
// WaveTile size to test for decomposition.
|
||||
// We expect the selector to pick a 16x16 fragment
|
||||
static constexpr uint32_t WaveTileM = 112;
|
||||
static constexpr uint32_t WaveTileN = 112;
|
||||
static constexpr uint32_t WaveTileK = 128;
|
||||
|
||||
// The expected block size from the selector (multiple of 16).
|
||||
// Note: Actual blockK might be slightly different due to hardware implementation, but the
|
||||
// The expected fragment size from the selector (MmaTile, multiple of 16).
|
||||
// Note: Actual FragK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t BlockM = 16;
|
||||
static constexpr uint32_t BlockN = 16;
|
||||
static constexpr uint32_t BlockK = 32;
|
||||
static constexpr uint32_t FragM = 16;
|
||||
static constexpr uint32_t FragN = 16;
|
||||
static constexpr uint32_t FragK = 32;
|
||||
|
||||
// Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
|
||||
// TODO: c++20 use is_target_family_gfx11(currentArchId)
|
||||
@@ -682,9 +551,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
uint32_t MultiplierC = 1;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA;
|
||||
uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB;
|
||||
uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC;
|
||||
uint32_t AElements = FragM * FragK / deviceWarpSize * MultiplierA;
|
||||
uint32_t BElements = FragN * FragK / deviceWarpSize * MultiplierB;
|
||||
uint32_t CElements = FragM * FragN / deviceWarpSize * MultiplierC;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
@@ -712,16 +581,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
// Output should be FragK for all elements, because the inputs are all 1's
|
||||
// Output should be WaveTileK for all elements, because the inputs are all 1's
|
||||
for(size_t i = 0; i < CElements; ++i)
|
||||
{
|
||||
CType expected = static_cast<CType>(FragK);
|
||||
CType expected = static_cast<CType>(WaveTileK);
|
||||
|
||||
EXPECT_NEAR(h_out[i], expected, 1e-3);
|
||||
}
|
||||
|
||||
@@ -55,17 +55,17 @@ namespace {
|
||||
* @tparam ADataType Data type of tensor A elements
|
||||
* @tparam BDataType Data type of tensor B elements
|
||||
* @tparam CDataType Data type of tensor C elements
|
||||
* @tparam BlockM M-dimension of the MMA tile
|
||||
* @tparam BlockN N-dimension of the MMA tile
|
||||
* @tparam BlockK K-dimension of the MMA tile
|
||||
* @tparam FragM M-dimension of the MMA tile
|
||||
* @tparam FragN N-dimension of the MMA tile
|
||||
* @tparam FragK K-dimension of the MMA tile
|
||||
* @tparam BlockSize HIP block size
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t BlockSize>
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
@@ -77,19 +77,18 @@ struct MmaLayoutTestKernel
|
||||
mma::MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockK,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
decltype(ck_tile::core::arch::get_compiler_target()),
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
|
||||
if constexpr(MmaTraits::IsSupported)
|
||||
if constexpr(mma::MmaOpTraits<MmaOp>::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaTraits::AVecType;
|
||||
using BVecType = typename MmaTraits::BVecType;
|
||||
using CVecType = typename MmaTraits::CVecType;
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
@@ -102,9 +101,9 @@ struct MmaLayoutTestKernel
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
|
||||
const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN);
|
||||
const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK;
|
||||
const uint32_t n = case_idx % MmaTraits::BlockN;
|
||||
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const uint32_t n = case_idx % MmaOp::kN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(uint32_t v = 0; v < a_vec_size; ++v)
|
||||
@@ -174,12 +173,12 @@ bool run_mma_layout_test()
|
||||
{
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
using ADataType = typename MmaTraits::ADataType;
|
||||
using BDataType = typename MmaTraits::BDataType;
|
||||
using CDataType = typename MmaTraits::CDataType;
|
||||
constexpr uint32_t BlockM = MmaTraits::BlockM;
|
||||
constexpr uint32_t BlockN = MmaTraits::BlockN;
|
||||
constexpr uint32_t BlockK = MmaTraits::BlockK;
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
using CDataType = typename MmaOp::CDataType;
|
||||
constexpr uint32_t FragM = MmaOp::kM;
|
||||
constexpr uint32_t FragN = MmaOp::kN;
|
||||
constexpr uint32_t FragK = MmaOp::kK;
|
||||
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
|
||||
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
|
||||
|
||||
@@ -202,7 +201,7 @@ bool run_mma_layout_test()
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr uint32_t total_cases = BlockM * BlockK * BlockN;
|
||||
constexpr uint32_t total_cases = FragM * FragK * FragN;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
@@ -213,9 +212,9 @@ bool run_mma_layout_test()
|
||||
using Kernel = MmaLayoutTestKernel<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockK,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
static_cast<int>(selector_wave_size)>;
|
||||
|
||||
std::ignore =
|
||||
@@ -232,9 +231,9 @@ bool run_mma_layout_test()
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (BlockK * BlockN);
|
||||
const uint32_t k = (case_idx / BlockN) % BlockK;
|
||||
const uint32_t n = case_idx % BlockN;
|
||||
const uint32_t m = case_idx / (FragK * FragN);
|
||||
const uint32_t k = (case_idx / FragN) % FragK;
|
||||
const uint32_t n = case_idx % FragN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -93,12 +92,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<1, 0>;
|
||||
@@ -176,12 +172,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
@@ -192,29 +185,41 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
using kCYs2RHsMajor = sequence<1>;
|
||||
using kCYs2RHsMinor = sequence<1>;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kAMLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kBNLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
using AWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kCMLane, MmaOp::kCM1PerLane>, sequence<MmaOp::kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
using BWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCMLane, kCM1PerLane>, sequence<kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -245,12 +250,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<0, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
|
||||
@@ -144,32 +144,29 @@ TEST(SparseMMATrait, SparseSelector)
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK>
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK>
|
||||
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
|
||||
{
|
||||
using CompilerTarget = decltype(get_compiler_target());
|
||||
using Selector = MmaDefaultSelector<AType,
|
||||
BType,
|
||||
CType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = MmaOpTraits<MmaOp>;
|
||||
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
|
||||
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
|
||||
|
||||
// Initialize the accumulator
|
||||
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
|
||||
|
||||
// Accumulate input AxB over FragK/BlockK iterations
|
||||
// Accumulate input AxB over WaveTileK/FragK iterations
|
||||
for(uint32_t i = 0; i < kIters; ++i)
|
||||
{
|
||||
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
|
||||
@@ -210,21 +207,21 @@ TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
|
||||
using BType = fp16_t;
|
||||
using CType = fp32_t;
|
||||
|
||||
// Fragment size, also the expected block size from the selector.
|
||||
// Note: Actual blockK might be slightly different due to hardware implementation, but the
|
||||
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
|
||||
// Note: Actual FragK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t FragM = 16;
|
||||
static constexpr uint32_t FragN = 16;
|
||||
static constexpr uint32_t FragK = 32;
|
||||
static constexpr uint32_t BlockM = FragM;
|
||||
static constexpr uint32_t BlockN = FragN;
|
||||
static constexpr uint32_t BlockK = FragK;
|
||||
static constexpr uint32_t WaveTileM = 16;
|
||||
static constexpr uint32_t WaveTileN = 16;
|
||||
static constexpr uint32_t WaveTileK = 32;
|
||||
static constexpr uint32_t FragM = WaveTileM;
|
||||
static constexpr uint32_t FragN = WaveTileN;
|
||||
static constexpr uint32_t FragK = WaveTileK;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = BlockM * BlockK / deviceWarpSize;
|
||||
uint32_t BElements = BlockN * BlockK / deviceWarpSize;
|
||||
uint32_t CElements = BlockM * BlockN / deviceWarpSize;
|
||||
uint32_t AElements = FragM * FragK / deviceWarpSize;
|
||||
uint32_t BElements = FragN * FragK / deviceWarpSize;
|
||||
uint32_t CElements = FragM * FragN / deviceWarpSize;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
|
||||
Reference in New Issue
Block a user