[CK Tile] Add Tile Distribution Encoding Calculator (#5515)

## Motivation

We want to be able to calculate TileDistributionEncodings describing
register mappings for any MmaOp. This is necessary for further
integration with CK Tile.

This MR adds a new struct TileDistrEncCalc, which takes an amdgcn_mma
type (MmaOp) and provides ABC warp distribution encodings for mapping
matrix fragment coordinates to register coordinates
(lane, vector item) and vice versa. It is able to take CTranpose,
Swizzle, and NumAccessA / NumAccessB template parameters for tweaking
the tile distributions. Swizzle modification will be implemented later.

The current implementation can deal with all intrinsic types and
block-hiding.

This MR also adds some additional static asserts and derived params
within amdgcn_mma_base, to enforce consistency and help calculate Tile
Distributions for block-hiding intrinsics.

An Example was added that uses the Tile Distr Enc Calc to calc and print
register layouts for Tile Distributions for some of our amdgcn_mma
structs. It also makes sure that the CTranspose modifier works as
intended.

Some additional gfx9 intrinsics were added to test block-hiding layouts
for the different types of C-block-hiding layouts.

The sparse intrinsic wrappers were updated according to Chris's recent
changes in another branch
(https://github.com/ROCm/rocm-libraries/pull/5508), which moved the
compression step outside of the intrinsic itself. This is necessary to
make sure that the Calculator can deal with this new interpretation of
the sparse intrinsics. I directly copied the new amdgcn structs from
Chris's branch and changed nothing else to avoid more complex merges in
the future. Note that this means I did not update a bunch of related
sparse code since that would be a lot, and therefore I disabled
test_amdgcn_sparse_mma for now.

The amdgcn_mma_layout test was refactored a bit:
- The old register mapping utility was removed and its use was replaced
by the new TileDistrEncCalc
- More tests were added to test layouts for different types of
block-hiding and sparse intrinsics
- The Selector method was removed and the tests were split up over
target architectures, with each target arch having a direct list of
amdgcn structs to be tested. This ensures that we force specific tests
on specific architectures and makes sure that the selector doesn't
quietly do some workarounds like creating compound intrinsics.

## Test Results

Layout tests based on calculated tile distribution encodings pass on all
architectures. Calculator works for all currently added amdgcn structs,
which includes different types of block-hiding and sparse intrinsics.
Printed layouts from new example verified by eye. CTranspose modifier
tested for large set of intrinsics.
This commit is contained in:
Kiefer van Teutem
2026-04-13 10:00:31 +02:00
committed by GitHub
parent 160bc1363e
commit 6cd016dde4
18 changed files with 623 additions and 665 deletions

View File

@@ -2,3 +2,4 @@
# SPDX-License-Identifier: MIT
add_executable(tile_example_tile_distr_enc_reg_map example_tile_distr_enc_reg_map.cpp)
add_executable(tile_example_tile_distr_enc_calc example_tile_distr_enc_calc.cpp)

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdio>
#include <type_traits>
#include <tuple>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/container/tuple.hpp"
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace mma;
using F16 = fp16_t;
using F32 = fp32_t;
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
using Target11 = decltype(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>());
using Target12 = decltype(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>());
template <typename MmaOp>
int check_tile_distr_enc()
{
using AEnc = typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding;
using BEnc = typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding;
using CEnc = typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding;
TileDistrEncRegMap<AEnc>::print();
TileDistrEncRegMap<BEnc>::print();
TileDistrEncRegMap<CEnc>::print();
// The only thing we check here is that CTranspose works as expected.
using AEncTransp = typename TileDistrEncCalc<MmaOp, true>::AWarpDstrEncoding;
using BEncTransp = typename TileDistrEncCalc<MmaOp, true>::BWarpDstrEncoding;
using CEncTransp = typename TileDistrEncCalc<MmaOp, true>::CWarpDstrEncoding;
// When using TransposeC, the A and B matrix layouts should be swapped.
static_assert(std::is_same<AEncTransp, BEnc>());
static_assert(std::is_same<BEncTransp, AEnc>());
// Make sure the C matrix layout is transposed in the CTranspose case.
int err = 0;
for(index_t lane = 0; lane < TileDistrEncRegMap<CEnc>::num_lanes; lane++)
{
for(index_t vec = 0; vec < TileDistrEncRegMap<CEnc>::num_vector_items; vec++)
{
auto coords = TileDistrEncRegMap<CEnc>::calc_matrix_indices_from_lane_vector(lane, vec);
auto coords_transp =
TileDistrEncRegMap<CEncTransp>::calc_matrix_indices_from_lane_vector(lane, vec);
if(coords[0] != coords_transp[1] || coords[1] != coords_transp[0])
{
err = 1;
printf("\033[31mLane %2d vec %2d maps to C matrix coords %2d %2d and transposed C "
"matrix coords %2d %2d, inconsistent!\033[0m\n",
lane,
vec,
coords[0],
coords[1],
coords_transp[0],
coords_transp[1]);
}
}
}
return err;
}
// List of intrinsics to test.
// clang-format off
using Intrinsics = ck_tile::tuple<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
>;
// clang-format on
int main()
{
int err = 0;
static_for<0, Intrinsics::size(), 1>{}([&](auto i) {
using MmaOp = std::tuple_element_t<i.value, Intrinsics>;
err |= check_tile_distr_enc<MmaOp>();
});
return err;
}