[CK Tile] Add Tile Distribution Encoding Calculator (#5515)

## Motivation

We want to be able to calculate TileDistributionEncodings describing
register mappings for any MmaOp. This is necessary for further
integration with CK Tile.

This MR adds a new struct, TileDistrEncCalc, which takes an amdgcn_mma
type (MmaOp) and provides A, B, and C warp distribution encodings for
mapping matrix fragment coordinates to register coordinates
(lane, vector item) and vice versa. It accepts CTranspose, Swizzle, and
NumAccessA / NumAccessB template parameters for tweaking the tile
distributions. Swizzle modification will be implemented later.
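For illustration, a minimal usage sketch; CtrlFlags and CompilerTarget stand for whatever flag/target types the surrounding code already uses, and only the TileDistrEncCalc interface itself is taken from this MR:

```cpp
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"

using namespace ck_tile::core::arch::mma;

// Any amdgcn_mma specialization works as the MmaOp, e.g. the fp16 16x16x16 one.
using Op   = amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget>;
using Calc = TileDistrEncCalc<Op>; // defaults: no CTranspose, fundamental NumAccess

using AEnc = typename Calc::AWarpDstrEncoding; // A fragment -> (lane, vector item)
using BEnc = typename Calc::BWarpDstrEncoding;
using CEnc = typename Calc::CWarpDstrEncoding;

// Same op, but with transposed C and a doubled NumAccess request for A.
using CalcT = TileDistrEncCalc<Op,
                               /*CTranspose=*/true,
                               /*SFactor=*/1,
                               /*AttrNumAccessA=*/2 * Op::kAKNumAccess>;
```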

The current implementation handles all intrinsic types as well as
block-hiding.

This MR also adds additional static asserts and derived parameters
within amdgcn_mma_base, both to enforce consistency and to help
calculate Tile Distributions for block-hiding intrinsics.

An example was added that uses TileDistrEncCalc to calculate and print
register layouts for Tile Distributions of some of our amdgcn_mma
structs. It also verifies that the CTranspose modifier works as
intended (see the sketch below).
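The CTranspose property being checked boils down to a type-level swap of the A and B encodings. A minimal sketch of such a check, assuming default NumAccess parameters on both sides (Op stands for any tested amdgcn_mma struct):

```cpp
#include <type_traits>

// With CTranspose the A and B warp distribution encodings swap roles, so the
// transposed calculator's A encoding must equal the untransposed calculator's
// B encoding (and vice versa).
static_assert(std::is_same_v<
    typename TileDistrEncCalc<Op, /*CTranspose=*/true>::AWarpDstrEncoding,
    typename TileDistrEncCalc<Op, /*CTranspose=*/false>::BWarpDstrEncoding>);
```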

Some additional gfx9 intrinsics were added so that the different types
of C-block-hiding layouts are covered by tests.

The sparse intrinsic wrappers were updated to match Chris's recent
changes in another branch
(https://github.com/ROCm/rocm-libraries/pull/5508), which move the
compression step outside of the intrinsic itself. This is necessary so
that the calculator can deal with this new interpretation of the sparse
intrinsics. I copied the new amdgcn structs verbatim from Chris's
branch and changed nothing else, to avoid more complex merges in the
future. Note that this means I did not update a bunch of related sparse
code, since that would be a large change; I therefore disabled
test_amdgcn_sparse_mma for now.

The amdgcn_mma_layout test was refactored a bit:
- The old register mapping utility was removed and its uses were
replaced by the new TileDistrEncCalc
- More tests were added to cover layouts for the different types of
block-hiding and sparse intrinsics
- The Selector method was removed and the tests were split up per
target architecture, with each target arch having a direct list of
amdgcn structs to be tested (see the sketch after this list). This
ensures that specific tests run on specific architectures and that the
selector cannot quietly apply workarounds such as creating compound
intrinsics.
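A sketch of what the per-architecture lists look like after the refactor; the scaffolding names here are illustrative, not the actual test code:

```cpp
// Illustrative only: each target architecture gets an explicit list of the
// amdgcn_mma structs it must test, so nothing is substituted behind our back.
#if defined(__gfx942__)
using TestedMmaOps = std::tuple<
    amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, Flags, Target>, // plain
    amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, Flags, Target>>; // block-hiding
#endif
```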

## Test Results

Layout tests based on calculated tile distribution encodings pass on
all architectures. The calculator works for all currently added amdgcn
structs, including the different types of block-hiding and sparse
intrinsics. Printed layouts from the new example were verified by eye,
and the CTranspose modifier was tested for a large set of intrinsics.
Commit 6cd016dde4 (parent 160bc1363e)
Author: Kiefer van Teutem
Date: 2026-04-13 10:00:31 +02:00
Committed by: GitHub
18 changed files with 623 additions and 665 deletions


@@ -32,6 +32,7 @@
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"


@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
@@ -87,7 +89,7 @@ namespace ck_tile::core::arch::mma {
*
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
-with a Num Access value of at least 2.
+with a Num Access value which is a multiple of 2.
*
* (load / store manipulation). It seems like the load and store tile functions end up looking for
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
@@ -102,13 +104,16 @@ namespace ck_tile::core::arch::mma {
*
* -- CMPerLane --
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e
-the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
+the product of the sizes of the outermost and innermost dimensions after a double M unmerge. This
+does not count a potential increased M dimension size from block hiding. In this case, we have M
+= kCMBlock * M2 * M1 * M0 instead.
*
* -- CNumAccess --
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
* and will not try to request a specific value. Absolutely needed for logical correctness of
* register mappings since we can not perform arbitrary M permutations without messing up the A
-layout.
+layout. This does not count a potential increased M dimension size from block hiding. In this
+case, we have M = kCMBlock * M2 * M1 * M0 instead.
*/
/**
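To make the factorization above concrete, here is a hypothetical block-hiding configuration worked through with the unmerge formulas the calculator uses; the sizes are made up, not taken from a specific intrinsic:

```cpp
// Hypothetical sizes, consistent with the definitions above.
constexpr int kM = 64, kCMBlocks = 2, kCMPerLane = 16, kCMNumAccess = 4;

// The calculator unmerges M into (kCMBlocks, M2, M1, M0):
constexpr int M2 = kCMNumAccess;                // 4, outermost per-lane dim
constexpr int M1 = kM / kCMBlocks / kCMPerLane; // 2, the lane dim
constexpr int M0 = kCMPerLane / kCMNumAccess;   // 4, innermost per-lane dim

static_assert(kM == kCMBlocks * M2 * M1 * M0);  // 64 = 2 * 4 * 2 * 4
static_assert(kCMPerLane == M2 * M0);           // CMPerLane excludes the block factor
```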
@@ -144,7 +149,7 @@ struct amdgcn_mma_base
using CDataType = CDataType_;
// Fragment (MmaTile) sizes, check description above.
-static constexpr index_t kM = FragM; // M = M2 * M1 * M0
+static constexpr index_t kM = FragM; // M = M2 * M1 * M0 (* kCMBlocks when block-hiding)
static constexpr index_t kN = FragN;
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
@@ -157,15 +162,37 @@ struct amdgcn_mma_base
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
// K-dimension compression ratio for A matrix, always 2 for sparse intrinsics.
static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 2 : 1;
// Layout checks
static_assert(kK % kABKPerLane == 0);
static_assert(kABKPerLane % kAKNumAccess == 0);
static_assert(kABKPerLane % kBKNumAccess == 0);
static_assert(kCMPerLane % kCMNumAccess == 0);
// Register types (derived)
static constexpr index_t WaveSize = WaveSize_;
-static_assert((kM * kK * kARepeat) % WaveSize == 0);
+static_assert((kM * kK * kARepeat) % (WaveSize * kCompressionRatio) == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);
-using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
+using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize / kCompressionRatio>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
// Block-hiding / repeat related traits (derived)
static_assert(kARepeat == kBRepeat || !std::is_same_v<OpType, WmmaOp>);
static_assert(kARepeat == 1 || kBRepeat == 1 || !std::is_same_v<OpType, MfmaOp>);
static constexpr index_t kCMBlocks = std::is_same_v<OpType, MfmaOp> ? kBRepeat : 1;
static constexpr index_t kCNBlocks = std::is_same_v<OpType, MfmaOp> ? kARepeat : 1;
static_assert(kM % (kCMBlocks * kCMPerLane) == 0);
static_assert(kN % kCNBlocks == 0);
// For the C matrix, the block dimension B is either put in the Vector dimension or the Lane
// dimension. We can tell which by checking if we get the right Vector size.
static constexpr bool CBlockDimInVecDim =
kCMBlocks * kCNBlocks * kCMPerLane == vector_traits<CVecType>::vector_size;
};
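Plugging made-up numbers into the CBlockDimInVecDim check above shows the two cases; these values are illustrative, not tied to a particular intrinsic:

```cpp
// Case 1: the C vector has room for the blocks -> blocks live in the Vector dim.
static_assert(2 /*kCMBlocks*/ * 1 /*kCNBlocks*/ * 16 /*kCMPerLane*/ ==
              32 /*CVec size*/);
// Case 2: the C vector only holds the per-lane M elements -> the blocks must be
// spread over the Lane dim instead.
static_assert(16 /*kCMBlocks*/ * 1 /*kCNBlocks*/ * 4 /*kCMPerLane*/ !=
              4 /*CVec size*/);
```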
/**
@@ -181,6 +208,7 @@ struct Unsupported;
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
*/
// TODO: Make sure this actually matches amdgcn_mma.
template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
@@ -194,7 +222,6 @@ concept MmaOpI = requires(MmaOp op) {
typename MmaOp::AVecType;
typename MmaOp::BVecType;
typename MmaOp::CVecType;
// Captures CK-specific layout properties
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;


@@ -51,6 +51,82 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, 64u, 4, 1, 1, 1, 2, 16, 4, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, 64u, 4, 1, 2, 1, 1, 16, 4, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, 64u, 4, 1, 1, 1, 16, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, 64u, 4, 1, 16, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets


@@ -6,7 +6,6 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -31,25 +30,12 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
// clang-format on
{
CK_TILE_DEVICE static auto
-exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
+exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
{
-static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
-static constexpr index_t kCompressionRatio = 2;
-static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
-using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
-static_assert(CompressedSize == 4);
-// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
-// and evaluate changing this to a transform at a higher level.
-// aVec not being const can cause problems when running multiple intrinsics.
-const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
-const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
-a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
+aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
}
};
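With this change the caller owns the compression step. A sketch of the resulting call pattern; the compression helper is hypothetical here (the real transform lives at a higher level, cf. sparse_transforms.hpp):

```cpp
// Illustrative only. Somewhere above the intrinsic, A is 2:4-compressed into
// a half-width vector plus a 32-bit metadata index...
auto const [aCompressed, idx] = compress_a_somehow(aFull); // hypothetical helper

// ...and the wrapper now just forwards everything, A already compressed.
auto c = SparseOp::exec(aCompressed, bVec, cAcc, idx);
```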


@@ -43,18 +43,15 @@ struct BuiltinParams
template <SparseCompressionIndex Idx>
static constexpr BuiltinParams getBuiltinParams()
{
-BuiltinParams params;
// TODO c++20: designated initializers
if constexpr(Idx == SparseCompressionIndex::FIRST)
{
-params.UseFirstIndex = 1;
-params.ByteIndexToOverride = 0;
+return BuiltinParams{1, 0};
}
else
{
-params.UseFirstIndex = 0;
-params.ByteIndexToOverride = static_cast<int>(Idx);
+return BuiltinParams{0, static_cast<int>(Idx)};
}
-return params;
}
} // namespace sparse::detail


@@ -7,7 +7,6 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -21,23 +20,9 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
// clang-format on
{
CK_TILE_DEVICE static auto
-exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
+exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
{
-static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
-static constexpr index_t kCompressionRatio = 2;
-static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
-using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
-static_assert(CompressedSize == 8);
-// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
-// and evaluate changing this to a transform at a higher level.
-// aVec not being const can cause problems when running multiple intrinsics.
-const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
-const AVecCompressed a_vec_pruned = {
-aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
-return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
+return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
}
};


@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
namespace ck_tile::core::arch::mma {
/**
* @class TileDistrEncCalc
* @brief Given an MmaOp and modifiers, provides warp-level tile distribution encodings for mapping
* ABC matrix fragment coordinates to register coordinates (lane, vector item) and vice versa.
* @tparam MmaOp Intrinsic (amdgcn_mma).
* @tparam CTranspose Whether we are using CTranspose.
* @tparam SFactor Swizzle factor. Not implemented.
* @tparam AttrNumAccessA Requested NumAccess for the A matrix. Must be multiple of "fundamental"
* NumAccess for intrinsic. See details in amdgcn_mma.hpp.
* @tparam AttrNumAccessB Requested NumAccess for the B matrix.
*/
template <typename MmaOp,
bool CTranspose = false,
index_t SFactor = 1,
index_t AttrNumAccessA = MmaOp::kAKNumAccess,
index_t AttrNumAccessB = MmaOp::kBKNumAccess>
struct TileDistrEncCalc
{
private:
static constexpr index_t NumAccessA = std::max(MmaOp::kAKNumAccess, AttrNumAccessA);
static constexpr index_t NumAccessB = std::max(MmaOp::kBKNumAccess, AttrNumAccessB);
// We are free to choose any NumAccess value to manipulate the load / store behavior, unless the
// intrinsic fundamentally requires a base NumAccess factor for the layout to be correct.
static_assert(AttrNumAccessA % MmaOp::kAKNumAccess == 0,
"Requesting NumAccessA incompatible with builtin.");
static_assert(AttrNumAccessB % MmaOp::kBKNumAccess == 0,
"Requesting NumAccessB incompatible with builtin.");
static_assert(MmaOp::kABKPerLane % NumAccessA == 0);
static_assert(MmaOp::kABKPerLane % NumAccessB == 0);
static_assert(SFactor == 1, "Swizzle not implemented yet."); // TODO: Implement Swizzle.
template <index_t MajorDimSize, index_t Repeat, index_t NumAccess, index_t CompressionRatio = 1>
using ABWarpDstrEnc = tile_distribution_encoding<
sequence<Repeat>,
tuple<sequence<MajorDimSize>,
sequence<NumAccess,
MmaOp::kK / MmaOp::kABKPerLane,
MmaOp::kABKPerLane / NumAccess / CompressionRatio>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
static constexpr auto get_cwarp_dstr_encoding()
{
// We unmerge the M and N dimensions in the same way every time.
using MSubDims = sequence<MmaOp::kCMBlocks,
MmaOp::kCMNumAccess,
MmaOp::kM / MmaOp::kCMBlocks / MmaOp::kCMPerLane,
MmaOp::kCMPerLane / MmaOp::kCMNumAccess>;
using NSubDims = sequence<MmaOp::kCNBlocks, MmaOp::kN / MmaOp::kCNBlocks>;
// In case of CTranspose, all we do is swap the M and N dimension.
using MatDims =
std::conditional_t<CTranspose, tuple<NSubDims, MSubDims>, tuple<MSubDims, NSubDims>>;
constexpr int MInx = CTranspose ? 2 : 1;
constexpr int NInx = CTranspose ? 1 : 2;
// For MFMA intrinsics with blocks, the block dimensions might be in the Lane dim or in the
// Vec dim, so we get different merge orderings.
if constexpr(MmaOp::CBlockDimInVecDim)
{
return tile_distribution_encoding<sequence<1>,
MatDims,
tuple<sequence<MInx, NInx>>,
tuple<sequence<2, 1>>,
sequence<MInx, NInx, MInx, MInx>,
sequence<0, 0, 1, 3>>{};
}
else
{
return tile_distribution_encoding<sequence<1>,
MatDims,
tuple<sequence<MInx, NInx, MInx, NInx>>,
tuple<sequence<0, 0, 2, 1>>,
sequence<MInx, MInx>,
sequence<1, 3>>{};
}
}
using AEnc_ = ABWarpDstrEnc<MmaOp::kM, MmaOp::kARepeat, NumAccessA, MmaOp::kCompressionRatio>;
using BEnc_ = ABWarpDstrEnc<MmaOp::kN, MmaOp::kBRepeat, NumAccessB>;
public:
// When using CTranspose, the A and B matrices are swapped.
using AWarpDstrEncoding = std::conditional_t<CTranspose, BEnc_, AEnc_>;
using BWarpDstrEncoding = std::conditional_t<CTranspose, AEnc_, BEnc_>;
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// Some additional consistency checks
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::AVecType>::vector_size);
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::BVecType>::vector_size);
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::CVecType>::vector_size);
};
} // namespace ck_tile::core::arch::mma
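The consistency checks at the bottom also show how these encodings get consumed: TileDistrEncRegMap exposes num_lanes and num_vector_items for an encoding. A small sketch of the invariant this gives you (Op stands for any amdgcn_mma struct):

```cpp
using Calc = TileDistrEncCalc<Op>;
using CMap = TileDistrEncRegMap<typename Calc::CWarpDstrEncoding>;

// Every element of the kM x kN C fragment maps to exactly one
// (lane, vector item) register coordinate within the wave.
static_assert(CMap::num_lanes * CMap::num_vector_items == Op::kM * Op::kN);
```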