[rocm-libraries] ROCm/rocm-libraries#4495 (commit 5664eb0)

Adding layout test for amdgcn_mma structs

## Motivation

Currently, the test suite for `amdgcn_mma` focuses on the design (e.g.
choosing the correct specialization based on SFINAE) and a single live
test that checks if selected MmaOp runs. This PR adds a simplified GEMM
test kernel that checks the exact layout of the selected MmaOp.

## Technical Details

The test in `test_amdgcn_mma_layout.cpp` launches MxKxN test cases (one
per block), where each case:
1. Constructs A and B tensors on a device with a single 1 at A(m,k) and
B(k,n) (rest is all 0s)
  2. Executes the MMA intrinsic.
  3. Checks whether C has the "1" in the expected position.

For the MMA intrinsic, it pulls an Mma op from the amdgcn_mma
specialization based on the given inputs (tile dimensions, data types).

Note 1: As a helper, in `test_amdgcn_mma_layout_util.hpp` we add a
register map for a given amdgcn_mma specialization. The register mapping
is currently based on the `tile_distribution_encoding`.

Note 2: Everything is added to the test suite, no additions to the
actual `amdgcn_mma` structs. All the extra information that is needed,
but not yet provided by `amdgcn_mma` structs, is added as a boilerplate
to the header. TODO: Rebase this PR on top of the `amdgcn_mma` refactor
or clean it up after merge.

## Test Plan

This PR solely adds a new test to the existing code.

## Test Result

Tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Wojciech Laskowski
2026-03-06 16:22:16 +00:00
committed by assistant-librarian[bot]
parent b0c13f3124
commit b80e41f3bc
3 changed files with 586 additions and 0 deletions

View File

@@ -17,3 +17,11 @@ if(GPU_TARGETS MATCHES "gfx9")
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
# MMA register-layout validation test. Built for every GPU family that has an
# amdgcn_mma specialization (gfx9 MFMA, gfx11/gfx12 WMMA); skipped otherwise.
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
    add_gtest_executable(test_amdgcn_mma_layout test_amdgcn_mma_layout.cpp)
    target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
    message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target")
endif()

View File

@@ -0,0 +1,286 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <gtest/gtest.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>
#include <cstdint>
#include <cstdlib>
#include <vector>
#include <cstddef>
#include <string>
#include <tuple>
namespace ck = ck_tile;
namespace mma = ck_tile::core::arch::mma;
// MMA register layout validation test for amdgcn_mma structs.
//
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors
// A and B that contain exactly one non-zero element each, placed so that their product
// contributes to a single output element C(m, n):
//
// A (M x K) B (K x N) C = A * B (M x N)
// . . . . . . . . . . . . . . . . . . . . . . . .
// . . . . . . . . . . . . . . . . . . . . . . . .
// . . . 1 . . . . . . . . . . . . . . . . . . . .
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
// . . . . . . . . . . . . . . . . . . . . . . . .
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
//
// The kernel uses RegisterMap to scatter A and B into the correct (lane, vecIdx) positions
// of the MMA fragment registers, executes the intrinsic, then uses RegisterMap again to
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
// location.
namespace {
/**
 * @class MmaLayoutTestKernel
 * @brief Device kernel that performs C = AB using a given Mma op and checks the
 *        register layout of the result.
 *
 * Each HIP block handles exactly one (m, k, n) test case, decoded from blockIdx.x.
 * A single "1" is scattered into the A and B fragments via RegisterMap, the MMA
 * intrinsic is executed, and the C fragment is read back through RegisterMap and
 * compared against the expected single "1" at C(m, n).
 *
 * @tparam ADataType Data type of tensor A elements
 * @tparam BDataType Data type of tensor B elements
 * @tparam CDataType Data type of tensor C elements
 * @tparam BlockM M-dimension of the MMA tile
 * @tparam BlockN N-dimension of the MMA tile
 * @tparam BlockK K-dimension of the MMA tile
 * @tparam BlockSize HIP block size (the launcher passes the selector's wave size)
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockK,
          uint32_t BlockSize>
struct MmaLayoutTestKernel
{
    static constexpr int kBlockSize = BlockSize;

    /// @param error_flags one uint32_t per test case; written 1 if any lane saw a mismatch,
    ///                    0 otherwise (only when the selected op is supported on this target)
    __device__ void operator()(uint32_t* error_flags) const
    {
        // Re-run the selector on-device so the kernel exercises exactly the op
        // the selector picks for this (type, tile, compile target) combination.
        using Selector =
            mma::MmaDefaultSelector<ADataType,
                                    BDataType,
                                    CDataType,
                                    BlockM,
                                    BlockN,
                                    BlockK,
                                    decltype(ck_tile::core::arch::get_compiler_target())>;
        using MmaOp = typename Selector::SelectedOp;
        using MmaTraits = mma::MmaOpTraits<MmaOp>;
        if constexpr(MmaTraits::IsSupported)
        {
            using AVecType = typename MmaTraits::AVecType;
            using BVecType = typename MmaTraits::BVecType;
            using CVecType = typename MmaTraits::CVecType;
            constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
            constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
            constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
            const uint32_t lane = threadIdx.x;
            // Zero-initialized fragments: everything except the injected "1"s stays 0.
            AVecType a_frag{};
            BVecType b_frag{};
            CVecType c_frag{};
            // get (m, k, n), where "1" should be placed for this block;
            // case index enumerates n fastest, then k, then m (must match the host-side decode).
            const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
            const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN);
            const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK;
            const uint32_t n = case_idx % MmaTraits::BlockN;
            // place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping;
            // each lane checks whether any of its vector slots maps to (m, k).
            for(uint32_t v = 0; v < a_vec_size; ++v)
            {
                auto a_coords = RegisterMap<MmaOp>::Register2AMap(lane, v);
                if(static_cast<uint32_t>(a_coords[0]) == m &&
                   static_cast<uint32_t>(a_coords[1]) == k)
                {
                    a_frag[v] = static_cast<ADataType>(1);
                }
            }
            // Same for B — note the B mapping returns coordinates in (n, k) order.
            for(uint32_t v = 0; v < b_vec_size; ++v)
            {
                auto b_coords = RegisterMap<MmaOp>::Register2BMap(lane, v);
                if(static_cast<uint32_t>(b_coords[0]) == n &&
                   static_cast<uint32_t>(b_coords[1]) == k)
                {
                    b_frag[v] = static_cast<BDataType>(1);
                }
            }
            c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
            // Verify: exactly C(m, n) must be 1, every other element 0 (within tolerance).
            uint32_t err = 0;
            const CDataType tol = static_cast<CDataType>(
                1.0e-1f); // TODO: this tolerance might not be suitable for all data types and
                          // should be revisited if we add more configurations
            for(uint32_t v = 0; v < c_vec_size; ++v)
            {
                auto c_coords = RegisterMap<MmaOp>::Register2CMap(lane, v);
                const uint32_t i = static_cast<uint32_t>(c_coords[0]);
                const uint32_t j = static_cast<uint32_t>(c_coords[1]);
                const CDataType expected =
                    (i == m && j == n) ? static_cast<CDataType>(1) : static_cast<CDataType>(0);
                const CDataType value = static_cast<CDataType>(c_frag[v]);
                if(fabsf(static_cast<float>(value - expected)) > static_cast<float>(tol))
                {
                    err = 1;
                }
            }
            // Wavefront-wide reduction: any lane's failure fails the whole case.
            const uint32_t any_err = __any(err);
            if(threadIdx.x == 0)
            {
                error_flags[case_idx] = any_err;
            }
        }
    }
};
/**
* @brief Test driver: runs the test for a given MMA configuration.
*
* The testlaunches (mkn) test cases (one per block) to check all possible positions of the "1" in
* the A/B tensors.
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
* 2. Executes MMA intrinsic to compute C tensor.
* 3. Checks if C has the 1 in the expected position.
*
* @tparam Selector Selector for the Mma operation
* @return true if the test ran on hardware; false if skipped (no device or unsupported)
*/
template <typename Selector>
bool run_mma_layout_test()
{
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = mma::MmaOpTraits<MmaOp>;
using ADataType = typename MmaTraits::ADataType;
using BDataType = typename MmaTraits::BDataType;
using CDataType = typename MmaTraits::CDataType;
constexpr uint32_t BlockM = MmaTraits::BlockM;
constexpr uint32_t BlockN = MmaTraits::BlockN;
constexpr uint32_t BlockK = MmaTraits::BlockK;
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
int device_count = 0;
hipDevice_t device{};
HIP_CHECK_ERROR(hipGetDevice(&device));
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
hipDeviceProp_t props{};
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
const auto runtime_target =
ck_tile::core::arch::hip_device_prop_gcn_arch_name_to_amdgcn_target_id(props.gcnArchName);
const bool has_device = device_count > 0;
if(!has_device || runtime_target == ck_tile::core::arch::amdgcn_target_id::HOST ||
runtime_target != selector_target_id ||
props.warpSize != static_cast<int>(selector_wave_size))
{
return false;
}
constexpr uint32_t total_cases = BlockM * BlockK * BlockN;
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
std::vector<uint32_t> h_errors(total_cases, 0u);
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
std::ignore = hipGetLastError();
using Kernel = MmaLayoutTestKernel<ADataType,
BDataType,
CDataType,
BlockM,
BlockN,
BlockK,
static_cast<int>(selector_wave_size)>;
std::ignore =
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
ck_tile::make_kernel(Kernel{},
dim3(total_cases),
dim3(static_cast<int>(selector_wave_size)),
0,
d_error_ptr));
HIP_CHECK_ERROR(hipMemcpyAsync(
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
{
const uint32_t m = case_idx / (BlockK * BlockN);
const uint32_t k = (case_idx / BlockN) % BlockK;
const uint32_t n = case_idx % BlockN;
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
}
return true;
}
} // namespace
// ==================== Test configurations per target ====================
// TODO: currently we have only 1 specific target per test. This should be revisited to enable all
// the targets within the family (gfx12, gfx11, gfx9)
// One representative compiler target per supported GPU family.
using MmaGfx1201CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx12_target<
                                          ck_tile::core::arch::amdgcn_target_id::GFX1201>());
using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_target<
                                         ck_tile::core::arch::amdgcn_target_id::GFX90A>());
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
                                          ck_tile::core::arch::amdgcn_target_id::GFX1100>());
// Selectors under test: the 16x16x16 fp16 x fp16 -> fp32 tile on each target.
using MmaGfx1201Selector = mma::
    MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1201CompilerTarget>;
using MmaGfx90aSelector = mma::
    MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx90aCompilerTarget>;
using MmaGfx1100Selector = mma::
    MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1100CompilerTarget>;
// clang-format off
using KernelTypes = ::testing::Types<
    MmaGfx1201Selector,
    MmaGfx90aSelector,
    MmaGfx1100Selector
>;
// clang-format on
// Typed test fixture: gtest instantiates the suite once per selector in KernelTypes.
template <typename Selector>
class TestMmaLayout : public ::testing::Test
{
};
TYPED_TEST_SUITE(TestMmaLayout, KernelTypes);
TYPED_TEST(TestMmaLayout, Mma_16x16x16_F16_F16_F32)
{
    // The driver runs on hardware only when the current device matches the selector's
    // target and wave size; otherwise the test is reported as skipped, not failed.
    bool executed = run_mma_layout_test<TypeParam>();
    if(!executed)
    {
        GTEST_SKIP() << "No supported HIP device found. Skipping test.";
    }
}

View File

@@ -0,0 +1,292 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include <cassert>
#include <cstdint>
namespace {
using namespace ck_tile;
/**
 * @class RegisterMapTraits
 * @brief Traits class that defines tile_distribution_encoding for each MmaOp
 *
 * The primary template is deliberately uninstantiable: every supported MmaOp
 * must provide an explicit specialization (see below).
 *
 * @tparam MmaOp amdgcn_mma specialization
 */
template <typename MmaOp>
struct RegisterMapTraits
{
    // sizeof(MmaOp) is never 0, but because the expression depends on MmaOp the
    // assertion only fires if this unspecialized template is actually instantiated.
    static_assert(!sizeof(MmaOp), "RegisterMapTraits requires a specialization");
};
/**
 * @class RegisterMap
 * @brief Translates (lane, vecIdx) register coordinates into matrix coordinates
 *        for the A/B/C fragments, using the encodings from RegisterMapTraits.
 * @tparam MmaOp amdgcn_mma specialization
 */
template <typename MmaOp>
struct RegisterMap
{
    using Traits = RegisterMapTraits<MmaOp>;
    using AMap = core::arch::mma::TileDistrEncRegMap<typename Traits::AWarpDstrEncoding>;
    using BMap = core::arch::mma::TileDistrEncRegMap<typename Traits::BWarpDstrEncoding>;
    using CMap = core::arch::mma::TileDistrEncRegMap<typename Traits::CWarpDstrEncoding>;

    /// (lane, vecIdx) -> matrix coordinates within the A fragment
    CK_TILE_HOST_DEVICE static auto Register2AMap(const uint32_t lane, const uint32_t vecIdx)
    {
        return to_matrix_indices<AMap>(lane, vecIdx);
    }
    /// (lane, vecIdx) -> matrix coordinates within the B fragment
    CK_TILE_HOST_DEVICE static auto Register2BMap(const uint32_t lane, const uint32_t vecIdx)
    {
        return to_matrix_indices<BMap>(lane, vecIdx);
    }
    /// (lane, vecIdx) -> matrix coordinates within the C fragment
    CK_TILE_HOST_DEVICE static auto Register2CMap(const uint32_t lane, const uint32_t vecIdx)
    {
        return to_matrix_indices<CMap>(lane, vecIdx);
    }

    private:
    /// Shared implementation: forwards to the encoding's lane/vector -> matrix mapping.
    template <typename Map>
    CK_TILE_HOST_DEVICE static auto to_matrix_indices(const uint32_t lane, const uint32_t vecIdx)
    {
        return Map::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
                                                         static_cast<index_t>(vecIdx));
    }
};
// ====================== Specializations per target =====================
/**
 * @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
 */
template <typename CtrlFlags, typename CompilerTarget>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
{
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget>;
    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
    // Wave size and fragment vector sizes, exposed for test-side sizing.
    static constexpr index_t WaveSize =
        static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
    static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
    // Major/minor sequences wiring the P (lane) and Y (per-lane vector) dimensions into
    // the R/H dimensions of tile_distribution_encoding.
    // NOTE(review): hand-derived for the GFX12 WMMA layout — verify against the
    // tile_distribution_encoding convention if these ever change.
    using kABPs2RHssMajor = sequence<2, 1>;
    using kABPs2RHssMinor = sequence<1, 0>;
    using kABYs2RHsMajor = sequence<2, 2>;
    using kABYs2RHsMinor = sequence<0, 2>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1, 1>;
    using kCYs2RHsMinor = sequence<0, 2>;
    // TODO: remove these and fix constants in amdgcn_mma
    // (duplicated here because amdgcn_mma does not yet expose them for gfx12)
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABK0PerLane = 1;
    static constexpr index_t kABKLane = 2;
    static constexpr index_t kABK1PerLane = 8;
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 8;
    // A fragment: M across lanes, K split into (per-lane, lane, per-lane) segments.
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<kAMLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;
    // B fragment: mirrors A with N in place of M.
    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<kBNLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;
    // C fragment: M split across per-lane/lane segments, N across lanes.
    using CWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kCM0PerLane, kCMLane, kCM1PerLane>,
                                         sequence<kCNLane>>, // <1, 2, 8>, <16>
                                   tuple<kCPs2RHssMajor>,
                                   tuple<kCPs2RHssMinor>,
                                   kCYs2RHsMajor,
                                   kCYs2RHsMinor>;
};
/**
 * @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
 */
template <typename CtrlFlags, typename CompilerTarget>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
{
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget>;
    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
    // Wave size and fragment vector sizes, exposed for test-side sizing.
    static constexpr index_t WaveSize =
        static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
    static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
    // Major/minor sequences wiring the P (lane) and Y (per-lane vector) dimensions into
    // the R/H dimensions of tile_distribution_encoding.
    // NOTE(review): hand-derived for the GFX9 MFMA layout — verify against the
    // tile_distribution_encoding convention if these ever change.
    using kABPs2RHssMajor = sequence<2, 1>;
    using kABPs2RHssMinor = sequence<0, 0>;
    using kABYs2RHsMajor = sequence<2>;
    using kABYs2RHsMinor = sequence<1>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<0, 0>;
    using kCYs2RHsMajor = sequence<1>;
    using kCYs2RHsMinor = sequence<1>;
    // Unlike gfx11/gfx12, the gfx9 amdgcn_mma already exposes the layout constants,
    // so the encodings read them directly from MmaOp.
    // A fragment: M across lanes, K split across lanes and the per-lane vector.
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<MmaOp::kAMLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;
    // B fragment: mirrors A with N in place of M.
    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<MmaOp::kBNLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;
    // C fragment: M split across lanes and the per-lane vector, N across lanes.
    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<MmaOp::kCMLane, MmaOp::kCM1PerLane>, sequence<MmaOp::kCNLane>>,
        tuple<kCPs2RHssMajor>,
        tuple<kCPs2RHssMinor>,
        kCYs2RHsMajor,
        kCYs2RHsMinor>;
};
/**
 * @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
 */
template <typename CtrlFlags, typename CompilerTarget>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
{
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget>;
    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
    // Wave size and fragment vector sizes, exposed for test-side sizing.
    static constexpr index_t WaveSize =
        static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
    static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
    // Major/minor sequences wiring the P (lane) and Y (per-lane vector) dimensions into
    // the R/H dimensions of tile_distribution_encoding. Note kABPs2RHssMajor starts at 0
    // here (unlike gfx9/gfx12): the A/B encodings use the R dimension for the repeat factor.
    // NOTE(review): hand-derived for the GFX11 WMMA layout — verify against the
    // tile_distribution_encoding convention if these ever change.
    using kABPs2RHssMajor = sequence<0, 1>;
    using kABPs2RHssMinor = sequence<0, 0>;
    using kABYs2RHsMajor = sequence<2>;
    using kABYs2RHsMinor = sequence<0>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1>;
    using kCYs2RHsMinor = sequence<0>;
    // TODO: remove these and fix constants in amdgcn_mma
    // (duplicated here because amdgcn_mma does not yet expose them for gfx11)
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABK0PerLane = 1;
    static constexpr index_t kABKLane = 1;
    static constexpr index_t kABK1PerLane = 16;
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 8;
    static constexpr index_t kCM1PerLane = 1;
    // A fragment: M across lanes, full K in the per-lane vector, replicated twice (kRepeat).
    using AWarpDstrEncoding =
        tile_distribution_encoding<sequence<2>, // kRepeat
                                   tuple<sequence<kAMLane>, sequence<kABK1PerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;
    // B fragment: mirrors A with N in place of M.
    using BWarpDstrEncoding =
        tile_distribution_encoding<sequence<2>, // kRepeat
                                   tuple<sequence<kBNLane>, sequence<kABK1PerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;
    // C fragment: M split across per-lane vector and lanes, N across lanes.
    using CWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kCM0PerLane, kCMLane>, sequence<kCNLane>>,
                                   tuple<kCPs2RHssMajor>,
                                   tuple<kCPs2RHssMinor>,
                                   kCYs2RHsMajor,
                                   kCYs2RHsMinor>;
};
// ========================================================================
} // namespace