mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK Tile] Add Tile Distribution Encoding Calculator (#5515)
## Motivation We want to be able to calculate TileDistributionEncodings describing register mappings for any MmaOp. This is necessary for further integration with CK Tile. This MR adds a new struct TileDistrEncCalc, which takes an amdgcn_mma type (MmaOp) and provides ABC warp distribution encodings for mapping matrix fragment coordinates to register coordinates (lane, vector item) and vice versa. It is able to take CTranpose, Swizzle, and NumAccessA / NumAccessB template parameters for tweaking the tile distributions. Swizzle modification will be implemented later. The current implementation can deal with all intrinsic types and block-hiding. This MR also adds some additional static asserts and derived params within amdgcn_mma_base, to enforce consistency and help calculate Tile Distributions for block-hiding intrinsics. An Example was added that uses the Tile Distr Enc Calc to calc and print register layouts for Tile Distributions for some of our amdgcn_mma structs. It also makes sure that the CTranspose modifier works as intended. Some additional gfx9 intrinsics were added to test block-hiding layouts for the different types of C-block-hiding layouts. The sparse intrinsic wrappers were updated according to Chris's recent changes in another branch (https://github.com/ROCm/rocm-libraries/pull/5508), which moved the compression step outside of the intrinsic itself. This is necessary to make sure that the Calculator can deal with this new interpretation of the sparse intrinsics. I directly copied the new amdgcn structs from Chris's branch and changed nothing else to avoid more complex merges in the future. Note that this means I did not update a bunch of related sparse code since that would be a lot, and therefore I disabled test_amdgcn_sparse_mma for now. The amdgcn_mma_layout test was refactored a bit: - The old register mapping utility was removed and its use was replaced by the new TileDistrEncCalc - More tests were added to test layouts for different types of block-hiding and sparse intrinsics - The Selector method was removed and the tests were split up over target architectures, with each target arch having a direct list of amdgcn structs to be tested. This ensures that we force specific tests on specific architectures and makes sure that the selector doesn't quietly do some workarounds like creating compound intrinsics. ## Test Results Layout tests based on calculated tile distribution encodings pass on all architectures. Calculator works for all currently added amdgcn structs, which includes different types of block-hiding and sparse intrinsics. Printed layouts from new example verified by eye. CTranspose modifier tested for large set of intrinsics.
This commit is contained in:
committed by
GitHub
parent
bfe574a430
commit
9fe98c864f
@@ -2,3 +2,4 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_tile_distr_enc_reg_map example_tile_distr_enc_reg_map.cpp)
|
||||
add_executable(tile_example_tile_distr_enc_calc example_tile_distr_enc_calc.cpp)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdio>
|
||||
#include <type_traits>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace mma;
|
||||
using F16 = fp16_t;
|
||||
using F32 = fp32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
using Target11 = decltype(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>());
|
||||
using Target12 = decltype(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>());
|
||||
|
||||
template <typename MmaOp>
|
||||
int check_tile_distr_enc()
|
||||
{
|
||||
using AEnc = typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding;
|
||||
using BEnc = typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding;
|
||||
using CEnc = typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding;
|
||||
|
||||
TileDistrEncRegMap<AEnc>::print();
|
||||
TileDistrEncRegMap<BEnc>::print();
|
||||
TileDistrEncRegMap<CEnc>::print();
|
||||
|
||||
// The only thing we check here is that CTranspose works as expected.
|
||||
using AEncTransp = typename TileDistrEncCalc<MmaOp, true>::AWarpDstrEncoding;
|
||||
using BEncTransp = typename TileDistrEncCalc<MmaOp, true>::BWarpDstrEncoding;
|
||||
using CEncTransp = typename TileDistrEncCalc<MmaOp, true>::CWarpDstrEncoding;
|
||||
|
||||
// When using TransposeC, the A and B matrix layouts should be swapped.
|
||||
static_assert(std::is_same<AEncTransp, BEnc>());
|
||||
static_assert(std::is_same<BEncTransp, AEnc>());
|
||||
|
||||
// Make sure the C matrix layout is transposed in the CTranspose case.
|
||||
int err = 0;
|
||||
for(index_t lane = 0; lane < TileDistrEncRegMap<CEnc>::num_lanes; lane++)
|
||||
{
|
||||
for(index_t vec = 0; vec < TileDistrEncRegMap<CEnc>::num_vector_items; vec++)
|
||||
{
|
||||
auto coords = TileDistrEncRegMap<CEnc>::calc_matrix_indices_from_lane_vector(lane, vec);
|
||||
auto coords_transp =
|
||||
TileDistrEncRegMap<CEncTransp>::calc_matrix_indices_from_lane_vector(lane, vec);
|
||||
|
||||
if(coords[0] != coords_transp[1] || coords[1] != coords_transp[0])
|
||||
{
|
||||
err = 1;
|
||||
printf("\033[31mLane %2d vec %2d maps to C matrix coords %2d %2d and transposed C "
|
||||
"matrix coords %2d %2d, inconsistent!\033[0m\n",
|
||||
lane,
|
||||
vec,
|
||||
coords[0],
|
||||
coords[1],
|
||||
coords_transp[0],
|
||||
coords_transp[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
// List of intrinsics to test.
|
||||
// clang-format off
|
||||
using Intrinsics = ck_tile::tuple<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
int main()
|
||||
{
|
||||
int err = 0;
|
||||
static_for<0, Intrinsics::size(), 1>{}([&](auto i) {
|
||||
using MmaOp = std::tuple_element_t<i.value, Intrinsics>;
|
||||
err |= check_tile_distr_enc<MmaOp>();
|
||||
});
|
||||
return err;
|
||||
}
|
||||
@@ -32,6 +32,7 @@
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -87,7 +89,7 @@ namespace ck_tile::core::arch::mma {
|
||||
*
|
||||
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
|
||||
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
|
||||
* with a Num Access value of at least 2.
|
||||
* with a Num Access value which is a multiple of 2.
|
||||
*
|
||||
* (load / store manipulation). It seems like the load and store tile functions end up looking for
|
||||
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
|
||||
@@ -102,13 +104,16 @@ namespace ck_tile::core::arch::mma {
|
||||
*
|
||||
* -- CMPerLane --
|
||||
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e
|
||||
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
|
||||
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge. This
|
||||
* does not count a potential increased M dimension size from block hiding. In this case, we have M
|
||||
* = kCMBlock * M2 * M1 * M0 instead.
|
||||
*
|
||||
* -- CNumAccess --
|
||||
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
|
||||
* and will not try to request a specific value. Absolutely needed for logical correctness of
|
||||
* register mappings since we can not perform arbitrary M permutations without messing up the A
|
||||
* layout.
|
||||
* layout. This does not count a potential increased M dimension size from block hiding. In this
|
||||
* case, we have M = kCMBlock * M2 * M1 * M0 instead.
|
||||
*/
|
||||
|
||||
/**
|
||||
@@ -144,7 +149,7 @@ struct amdgcn_mma_base
|
||||
using CDataType = CDataType_;
|
||||
|
||||
// Fragment (MmaTile) sizes, check description above.
|
||||
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
|
||||
static constexpr index_t kM = FragM; // M = M2 * M1 * M0 (* kCMBlocks when block-hiding)
|
||||
static constexpr index_t kN = FragN;
|
||||
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
|
||||
|
||||
@@ -157,15 +162,37 @@ struct amdgcn_mma_base
|
||||
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
|
||||
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
|
||||
|
||||
// K-dimension compression ratio for A matrix, always 2 for sparse intrinsics.
|
||||
static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 2 : 1;
|
||||
|
||||
// Layout checks
|
||||
static_assert(kK % kABKPerLane == 0);
|
||||
static_assert(kABKPerLane % kAKNumAccess == 0);
|
||||
static_assert(kABKPerLane % kBKNumAccess == 0);
|
||||
static_assert(kCMPerLane % kCMNumAccess == 0);
|
||||
|
||||
// Register types (derived)
|
||||
static constexpr index_t WaveSize = WaveSize_;
|
||||
static_assert((kM * kK * kARepeat) % WaveSize == 0);
|
||||
static_assert((kM * kK * kARepeat) % (WaveSize * kCompressionRatio) == 0);
|
||||
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
|
||||
static_assert((kM * kN) % WaveSize == 0);
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
|
||||
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize / kCompressionRatio>;
|
||||
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
|
||||
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
|
||||
|
||||
// Block-hiding / repeat related traits (derived)
|
||||
static_assert(kARepeat == kBRepeat || !std::is_same_v<OpType, WmmaOp>);
|
||||
static_assert(kARepeat == 1 || kBRepeat == 1 || !std::is_same_v<OpType, MfmaOp>);
|
||||
static constexpr index_t kCMBlocks = std::is_same_v<OpType, MfmaOp> ? kBRepeat : 1;
|
||||
static constexpr index_t kCNBlocks = std::is_same_v<OpType, MfmaOp> ? kARepeat : 1;
|
||||
static_assert(kM % (kCMBlocks * kCMPerLane) == 0);
|
||||
static_assert(kN % kCNBlocks == 0);
|
||||
|
||||
// For the C matrix, the block dimension B is either put in the Vector dimension or the Lane
|
||||
// dimension. We can tell which by checking if we get the right Vector size.
|
||||
static constexpr bool CBlockDimInVecDim =
|
||||
kCMBlocks * kCNBlocks * kCMPerLane == vector_traits<CVecType>::vector_size;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -181,6 +208,7 @@ struct Unsupported;
|
||||
* @concept MmaOpI
|
||||
* @brief Expresses the meta-data interface required for each MmaOp policy.
|
||||
*/
|
||||
// TODO: Make sure this actually matches amdgcn_mma.
|
||||
template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
@@ -194,7 +222,6 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
typename MmaOp::AVecType;
|
||||
typename MmaOp::BVecType;
|
||||
typename MmaOp::CVecType;
|
||||
|
||||
// Captures CK-specific layout properties
|
||||
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
|
||||
|
||||
@@ -51,6 +51,82 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, 64u, 4, 1, 1, 1, 2, 16, 4, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, 64u, 4, 1, 2, 1, 1, 16, 4, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, 64u, 4, 1, 1, 1, 16, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, 64u, 4, 1, 16, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -31,25 +30,12 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
{
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
|
||||
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,18 +43,15 @@ struct BuiltinParams
|
||||
template <SparseCompressionIndex Idx>
|
||||
static constexpr BuiltinParams getBuiltinParams()
|
||||
{
|
||||
BuiltinParams params;
|
||||
// TODO c++20: designated initializers
|
||||
if constexpr(Idx == SparseCompressionIndex::FIRST)
|
||||
{
|
||||
params.UseFirstIndex = 1;
|
||||
params.ByteIndexToOverride = 0;
|
||||
return BuiltinParams{1, 0};
|
||||
}
|
||||
else
|
||||
{
|
||||
params.UseFirstIndex = 0;
|
||||
params.ByteIndexToOverride = static_cast<int>(Idx);
|
||||
return BuiltinParams{0, static_cast<int>(Idx)};
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
} // namespace sparse::detail
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -21,23 +20,9 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
{
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 8);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {
|
||||
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
|
||||
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
/**
|
||||
* @class TileDistrEncCalc
|
||||
* @brief Given an MmaOp and modifiers, provides warp-level tile distribution encodings for mapping
|
||||
* ABC matrix fragment coordinates to register coordinates (lane, vector item) and vice versa.
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma).
|
||||
* @tparam CTranspose Whether we are using CTranspose.
|
||||
* @tparam SFactor Swizzle factor. Not implemented.
|
||||
* @tparam AttrNumAccessA Requested NumAccess for the A matrix. Must be multiple of "fundamental"
|
||||
* NumAccess for intrinsic. See details in amdgcn_mma.hpp.
|
||||
* @tparam AttrNumAccessB Requested NumAccess for the B matrix.
|
||||
*/
|
||||
template <typename MmaOp,
|
||||
bool CTranspose = false,
|
||||
index_t SFactor = 1,
|
||||
index_t AttrNumAccessA = MmaOp::kAKNumAccess,
|
||||
index_t AttrNumAccessB = MmaOp::kBKNumAccess>
|
||||
struct TileDistrEncCalc
|
||||
{
|
||||
private:
|
||||
static constexpr index_t NumAccessA = std::max(MmaOp::kAKNumAccess, AttrNumAccessA);
|
||||
static constexpr index_t NumAccessB = std::max(MmaOp::kBKNumAccess, AttrNumAccessB);
|
||||
|
||||
// We are free to choose any NumAccess value to manipulate the load / store behavior, unless the
|
||||
// intrinsic fundamentally requires a base NumAccess factor for the layout to be correct.
|
||||
static_assert(AttrNumAccessA % MmaOp::kAKNumAccess == 0,
|
||||
"Requesting NumAccessA incompatible with builtin.");
|
||||
static_assert(AttrNumAccessB % MmaOp::kBKNumAccess == 0,
|
||||
"Requesting NumAccessB incompatible with builtin.");
|
||||
|
||||
static_assert(MmaOp::kABKPerLane % NumAccessA == 0);
|
||||
static_assert(MmaOp::kABKPerLane % NumAccessB == 0);
|
||||
static_assert(SFactor == 1, "Swizzle not implemented yet."); // TODO: Implement Swizzle.
|
||||
|
||||
template <index_t MajorDimSize, index_t Repeat, index_t NumAccess, index_t CompressionRatio = 1>
|
||||
using ABWarpDstrEnc = tile_distribution_encoding<
|
||||
sequence<Repeat>,
|
||||
tuple<sequence<MajorDimSize>,
|
||||
sequence<NumAccess,
|
||||
MmaOp::kK / MmaOp::kABKPerLane,
|
||||
MmaOp::kABKPerLane / NumAccess / CompressionRatio>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
// We unmerge the M and N dimensions in the same way every time.
|
||||
using MSubDims = sequence<MmaOp::kCMBlocks,
|
||||
MmaOp::kCMNumAccess,
|
||||
MmaOp::kM / MmaOp::kCMBlocks / MmaOp::kCMPerLane,
|
||||
MmaOp::kCMPerLane / MmaOp::kCMNumAccess>;
|
||||
using NSubDims = sequence<MmaOp::kCNBlocks, MmaOp::kN / MmaOp::kCNBlocks>;
|
||||
|
||||
// In case of CTranspose, all we do is swap the M and N dimension.
|
||||
using MatDims =
|
||||
std::conditional_t<CTranspose, tuple<NSubDims, MSubDims>, tuple<MSubDims, NSubDims>>;
|
||||
constexpr int MInx = CTranspose ? 2 : 1;
|
||||
constexpr int NInx = CTranspose ? 1 : 2;
|
||||
|
||||
// For MFMA intrinsics with blocks, the block dimensions might be in the Lane dim or in the
|
||||
// Vec dim, so we get different merge orderings.
|
||||
if constexpr(MmaOp::CBlockDimInVecDim)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<1>,
|
||||
MatDims,
|
||||
tuple<sequence<MInx, NInx>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
sequence<MInx, NInx, MInx, MInx>,
|
||||
sequence<0, 0, 1, 3>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<sequence<1>,
|
||||
MatDims,
|
||||
tuple<sequence<MInx, NInx, MInx, NInx>>,
|
||||
tuple<sequence<0, 0, 2, 1>>,
|
||||
sequence<MInx, MInx>,
|
||||
sequence<1, 3>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AEnc_ = ABWarpDstrEnc<MmaOp::kM, MmaOp::kARepeat, NumAccessA, MmaOp::kCompressionRatio>;
|
||||
using BEnc_ = ABWarpDstrEnc<MmaOp::kN, MmaOp::kBRepeat, NumAccessB>;
|
||||
|
||||
public:
|
||||
// When using CTranspose, the A and B matrices are swapped.
|
||||
using AWarpDstrEncoding = std::conditional_t<CTranspose, BEnc_, AEnc_>;
|
||||
using BWarpDstrEncoding = std::conditional_t<CTranspose, AEnc_, BEnc_>;
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// Some additional consistency checks
|
||||
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
|
||||
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::AVecType>::vector_size);
|
||||
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::BVecType>::vector_size);
|
||||
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::CVecType>::vector_size);
|
||||
};
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -7,10 +7,11 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
# TODO: This test is temporarily disabled for cooperation / work planning reasons. Re-enable after merging related work.
|
||||
# if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
# add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
# target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
|
||||
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
@@ -18,10 +19,28 @@ else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_amdgcn_mma_layout test_amdgcn_mma_layout.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target")
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx942 test_amdgcn_mma_layout_gfx942.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx942 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
#include "test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace ck = ck_tile;
|
||||
namespace mma = ck_tile::core::arch::mma;
|
||||
|
||||
// MMA register layout validation test for amdgcn_mma structs.
|
||||
//
|
||||
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors
|
||||
// A and B that contain exactly one non-zero element each, placed so that their product
|
||||
// contributes to a single output element C(m, n):
|
||||
//
|
||||
// A (M x K) B (K x N) C = A * B (M x N)
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . 1 . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
|
||||
//
|
||||
// The kernel uses RegisterMap to scatter A and B into the correct (lane, vecIdx) positions
|
||||
// of the MMA fragment registers, executes the intrinsic, then uses RegisterMap again to
|
||||
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
|
||||
// location.
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @class MmaLayoutTestKernel
|
||||
* @brief Device kernel that performs C = AB using a given Mma op
|
||||
*
|
||||
* @tparam ADataType Data type of tensor A elements
|
||||
* @tparam BDataType Data type of tensor B elements
|
||||
* @tparam CDataType Data type of tensor C elements
|
||||
* @tparam FragM M-dimension of the MMA tile
|
||||
* @tparam FragN N-dimension of the MMA tile
|
||||
* @tparam FragK K-dimension of the MMA tile
|
||||
* @tparam BlockSize HIP block size
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t BlockSize>
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
|
||||
__device__ void operator()(uint32_t* error_flags) const
|
||||
{
|
||||
using Selector =
|
||||
mma::MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
decltype(ck_tile::core::arch::get_compiler_target()),
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
|
||||
if constexpr(mma::MmaOpTraits<MmaOp>::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
|
||||
const uint32_t lane = threadIdx.x;
|
||||
|
||||
AVecType a_frag{};
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
|
||||
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const uint32_t n = case_idx % MmaOp::kN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(uint32_t v = 0; v < a_vec_size; ++v)
|
||||
{
|
||||
auto a_coords = RegisterMap<MmaOp>::Register2AMap(lane, v);
|
||||
if(static_cast<uint32_t>(a_coords[0]) == m &&
|
||||
static_cast<uint32_t>(a_coords[1]) == k)
|
||||
{
|
||||
a_frag[v] = static_cast<ADataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
for(uint32_t v = 0; v < b_vec_size; ++v)
|
||||
{
|
||||
auto b_coords = RegisterMap<MmaOp>::Register2BMap(lane, v);
|
||||
if(static_cast<uint32_t>(b_coords[0]) == n &&
|
||||
static_cast<uint32_t>(b_coords[1]) == k)
|
||||
{
|
||||
b_frag[v] = static_cast<BDataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
|
||||
uint32_t err = 0;
|
||||
const CDataType tol = static_cast<CDataType>(
|
||||
1.0e-1f); // TODO: this tolerance might not be suitable for all data types and
|
||||
// should be revisited if we add more configurations
|
||||
for(uint32_t v = 0; v < c_vec_size; ++v)
|
||||
{
|
||||
auto c_coords = RegisterMap<MmaOp>::Register2CMap(lane, v);
|
||||
const uint32_t i = static_cast<uint32_t>(c_coords[0]);
|
||||
const uint32_t j = static_cast<uint32_t>(c_coords[1]);
|
||||
|
||||
const CDataType expected =
|
||||
(i == m && j == n) ? static_cast<CDataType>(1) : static_cast<CDataType>(0);
|
||||
const CDataType value = static_cast<CDataType>(c_frag[v]);
|
||||
if(fabsf(static_cast<float>(value - expected)) > static_cast<float>(tol))
|
||||
{
|
||||
err = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t any_err = __any(err);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
error_flags[case_idx] = any_err;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Test driver: runs the test for a given MMA configuration.
|
||||
*
|
||||
* The testlaunches (mkn) test cases (one per block) to check all possible positions of the "1" in
|
||||
* the A/B tensors.
|
||||
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
|
||||
* 2. Executes MMA intrinsic to compute C tensor.
|
||||
* 3. Checks if C has the 1 in the expected position.
|
||||
*
|
||||
* @tparam Selector Selector for the Mma operation
|
||||
* @return true if the test ran on hardware; false if skipped (no device or unsupported)
|
||||
*/
|
||||
template <typename Selector>
|
||||
bool run_mma_layout_test()
|
||||
{
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
using CDataType = typename MmaOp::CDataType;
|
||||
constexpr uint32_t FragM = MmaOp::kM;
|
||||
constexpr uint32_t FragN = MmaOp::kN;
|
||||
constexpr uint32_t FragK = MmaOp::kK;
|
||||
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
|
||||
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
|
||||
|
||||
int device_count = 0;
|
||||
hipDevice_t device{};
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
|
||||
const auto runtime_target =
|
||||
ck_tile::core::arch::hip_device_prop_gcn_arch_name_to_amdgcn_target_id(props.gcnArchName);
|
||||
const bool has_device = device_count > 0;
|
||||
|
||||
if(!has_device || runtime_target == ck_tile::core::arch::amdgcn_target_id::HOST ||
|
||||
runtime_target != selector_target_id ||
|
||||
props.warpSize != static_cast<int>(selector_wave_size))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr uint32_t total_cases = FragM * FragK * FragN;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
|
||||
|
||||
std::ignore = hipGetLastError();
|
||||
|
||||
using Kernel = MmaLayoutTestKernel<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
static_cast<int>(selector_wave_size)>;
|
||||
|
||||
std::ignore =
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{},
|
||||
dim3(total_cases),
|
||||
dim3(static_cast<int>(selector_wave_size)),
|
||||
0,
|
||||
d_error_ptr));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (FragK * FragN);
|
||||
const uint32_t k = (case_idx / FragN) % FragK;
|
||||
const uint32_t n = case_idx % FragN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ==================== Test configurations per target ====================
|
||||
// TODO: currently we have only 1 specific target per test. This should be revisited to enable all
|
||||
// the targets within the family (gfx12, gfx11, gfx9)
|
||||
using MmaGfx1201CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx12_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1201>());
|
||||
using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX90A>());
|
||||
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1100>());
|
||||
|
||||
using MmaGfx1201Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1201CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx90aSelector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx90aCompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx1100Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1100CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
MmaGfx1201Selector,
|
||||
MmaGfx90aSelector,
|
||||
MmaGfx1100Selector
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Selector>
|
||||
class TestMmaLayout : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMmaLayout, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestMmaLayout, Mma_16x16x16_F16_F16_F32)
|
||||
{
|
||||
bool executed = run_mma_layout_test<TypeParam>();
|
||||
|
||||
if(!executed)
|
||||
{
|
||||
GTEST_SKIP() << "No supported HIP device found. Skipping test.";
|
||||
}
|
||||
}
|
||||
239
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc
Normal file
239
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc
Normal file
@@ -0,0 +1,239 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace mma;
|
||||
|
||||
using F16 = fp16_t;
|
||||
using F32 = fp32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
using Target942 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>());
|
||||
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
using Target11 = decltype(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>());
|
||||
using Target12 = decltype(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>());
|
||||
|
||||
// MMA register layout validation test for amdgcn_mma structs.
|
||||
//
|
||||
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors A
|
||||
// and B that contain exactly one non-zero element each, placed so that their product contributes to
|
||||
// a single output element C(m, n):
|
||||
//
|
||||
// A (M x K) B (K x N) C = A * B (M x N)
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . 1 . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
|
||||
//
|
||||
// The kernel uses TileDistrEncRegMap to scatter A and B into the correct (lane, vecIdx) positions
|
||||
// of the MMA fragment registers, executes the intrinsic, then uses TileDistrEncRegMap again to
|
||||
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
|
||||
// location.
|
||||
|
||||
/**
|
||||
* @class MmaLayoutTestKernel
|
||||
* @brief Device kernel that performs C = AB using a given Mma op
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma) to be tested
|
||||
*/
|
||||
template <typename MmaOp> // TODO: C++20 concept for MmaOp
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
static constexpr int kBlockSize = MmaOp::WaveSize;
|
||||
|
||||
__device__ void operator()(uint32_t* error_flags) const
|
||||
{
|
||||
using ARegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding>;
|
||||
using BRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding>;
|
||||
using CRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding>;
|
||||
|
||||
if constexpr(MmaOpTraits<MmaOp>::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
constexpr index_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr index_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr index_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
|
||||
const index_t lane = threadIdx.x;
|
||||
|
||||
AVecType a_frag{};
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
uint32_t sparse_idx{};
|
||||
static_assert(MmaOp::kCompressionRatio <= 2); // Allow only 4:2 compression (or no).
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const index_t case_idx = blockIdx.x;
|
||||
const index_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const index_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const index_t n = case_idx % MmaOp::kN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(index_t v = 0; v < a_vec_size; ++v)
|
||||
{
|
||||
auto a_coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
|
||||
// When dealing with sparse intrinsics, the A matrix is compressed in the K
|
||||
// direction and we just put our "1" in the k / 2 position (rounded down).
|
||||
if(a_coords[0] == m && a_coords[1] == (k / MmaOp::kCompressionRatio))
|
||||
{
|
||||
a_frag[v] = 1;
|
||||
|
||||
// Calc an appropriate sparse idx value for a single 1 in position k. We use a
|
||||
// baseline index of 0x88888888. This sends each compressed index i to
|
||||
// uncompressed index i * 2. If k is odd, we should send it to i * 2 + 1
|
||||
// instead. We update only the absolutely necessary pair of bits for this
|
||||
// (idx[v*2:v*2+1]). Note that this simple calculation works for any 4:2 sparse
|
||||
// intrinsic with up to 16 packed k elements per lane.
|
||||
sparse_idx = 0x88888888 | ((k % 2) << (v * 2));
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t v = 0; v < b_vec_size; ++v)
|
||||
{
|
||||
auto b_coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
if(b_coords[0] == n && b_coords[1] == k)
|
||||
{
|
||||
b_frag[v] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(MmaOpTraits<MmaOp>::IsSparse)
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, sparse_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
}
|
||||
|
||||
// TODO: this tolerance might not be suitable for all data types and
|
||||
// should be revisited if we add more configurations
|
||||
const float tolerance = 1.0e-1f;
|
||||
index_t err = 0;
|
||||
|
||||
for(index_t v = 0; v < c_vec_size; ++v)
|
||||
{
|
||||
auto c_coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
|
||||
const float expected = (c_coords[0] == m && c_coords[1] == n) ? 1 : 0;
|
||||
const float value = static_cast<float>(c_frag[v]);
|
||||
if(std::fabs(value - expected) > tolerance)
|
||||
{
|
||||
err = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t any_err = __any(err);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
error_flags[case_idx] = any_err;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Test driver: runs the test for a given MMA configuration.
|
||||
*
|
||||
* The testlaunches (mkn) test cases (one per block) to check all possible positions of the "1" in
|
||||
* the A/B tensors.
|
||||
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
|
||||
* 2. Executes MMA intrinsic to compute C tensor.
|
||||
* 3. Checks if C has the 1 in the expected position.
|
||||
*
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma) to be tested
|
||||
*/
|
||||
template <typename MmaOp> // TODO: C++20 concept for MmaOp
|
||||
void run_mma_layout_test()
|
||||
{
|
||||
EXPECT_TRUE(MmaOpTraits<MmaOp>::IsSupported) << "Unsupported MmaOp! Bad MmaOp in list!\n";
|
||||
|
||||
int device_count = 0;
|
||||
hipDevice_t device{};
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
|
||||
EXPECT_TRUE(device_count > 0) << "No device found!";
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
EXPECT_EQ(props.warpSize, static_cast<int>(MmaOp::WaveSize))
|
||||
<< "Device wavesize " << props.warpSize << " != Mma wavesize " << MmaOp::WaveSize;
|
||||
|
||||
constexpr uint32_t total_cases = MmaOp::kM * MmaOp::kN * MmaOp::kK;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
|
||||
|
||||
(void)hipGetLastError();
|
||||
|
||||
using Kernel = MmaLayoutTestKernel<MmaOp>;
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{}, dim3(total_cases), dim3(MmaOp::WaveSize), 0, d_error_ptr));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const uint32_t n = case_idx % MmaOp::kN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
}
|
||||
|
||||
// Lists of intrinsics to test.
|
||||
// clang-format off
|
||||
using Gfx9Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
|
||||
>;
|
||||
using Gfx942Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
|
||||
>;
|
||||
using Gfx950Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE> // mfma_f32_16x16x32_f16
|
||||
>;
|
||||
using Gfx11Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
|
||||
>;
|
||||
using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename MmaOp>
|
||||
class TestMmaLayout : public ::testing::Test
|
||||
{
|
||||
};
|
||||
} // namespace
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx11Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx11Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx12Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx12Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx9Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx9Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx942Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx942Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx950Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx950Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -1,306 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
/**
|
||||
* @class RegisterMapTraits
|
||||
* @brief Traits class that defines tile_distribution_encoding for each MmaOp
|
||||
* @tparam MmaOp amdgcn_mma specialization
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
struct RegisterMapTraits
|
||||
{
|
||||
static_assert(sizeof(MmaOp) == 0, "RegisterMapTraits requires a specialization");
|
||||
};
|
||||
|
||||
/**
|
||||
* @class RegisterMap
|
||||
* @brief Uses specialized RegisterMapTraits to get the encoding
|
||||
* @tparam MmaOp amdgcn_mma specialization
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
struct RegisterMap
|
||||
{
|
||||
using Traits = RegisterMapTraits<MmaOp>;
|
||||
|
||||
using AMap = core::arch::mma::TileDistrEncRegMap<typename Traits::AWarpDstrEncoding>;
|
||||
using BMap = core::arch::mma::TileDistrEncRegMap<typename Traits::BWarpDstrEncoding>;
|
||||
using CMap = core::arch::mma::TileDistrEncRegMap<typename Traits::CWarpDstrEncoding>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2AMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return AMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2BMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return BMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2CMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return CMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
};
|
||||
|
||||
// ====================== Specializations per target =====================
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
|
||||
*/
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABK1PerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 8;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<kAMLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<kBNLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCM0PerLane, kCMLane, kCM1PerLane>,
|
||||
sequence<kCNLane>>, // <1, 2, 8>, <16>
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
|
||||
*/
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
using kABYs2RHsMajor = sequence<2>;
|
||||
using kABYs2RHsMinor = sequence<1>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<0, 0>;
|
||||
using kCYs2RHsMajor = sequence<1>;
|
||||
using kCYs2RHsMinor = sequence<1>;
|
||||
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCMLane, kCM1PerLane>, sequence<kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
|
||||
*/
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<0, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
using kABYs2RHsMajor = sequence<2>;
|
||||
using kABYs2RHsMinor = sequence<0>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1>;
|
||||
using kCYs2RHsMinor = sequence<0>;
|
||||
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 1;
|
||||
static constexpr index_t kABK1PerLane = 16;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 8;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<2>, // kRepeat
|
||||
tuple<sequence<kAMLane>, sequence<kABK1PerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<2>, // kRepeat
|
||||
tuple<sequence<kBNLane>, sequence<kABK1PerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCM0PerLane, kCMLane>, sequence<kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user