mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4495 (commit 5664eb0)
Adding layout test for amdgcn_mma structs ## Motivation Currently, the test suite for `amdgcn_mma` focuses on the design (e.g. choosing the correct specialization based on SFINAE) and a single live test that checks if selected MmaOp runs. This PR adds a simplified GEMM test kernel that checks the exact layout of the selected MmaOp. ## Technical Details The test in `test_amdgcn_mma_layout.cpp` launches MxKxN test cases (one per block), where each case: 1. Constructs A and B tensors on a device with a single 1 at A(m,k) and B(k,n) (rest is all 0s) 2. Executes the MMA intrinsic. 3. Checks if C has the "1" on the expected position. For the MMA intrinsic, it pulls a Mma op from amdgcn_mma specialization based on a given input (tile dimension, data types). Note 1: As a helper, in `test_amdgcn_mma_layout_util.hpp` we add register map for a given amdgcn_mma specialization. Register mapping is currently based on the `tile_distribution_encoding`. Note 2: Everything is added to the test suite, no additions to the actual `amdgcn_mma` structs. All the extra information that is needed, but not yet provided by `amdgcn_mma` structs, is added as a boilerplate to the header. TODO: Rebase this PR on top of the `amdgcn_mma` refactor or clean it up after merge. ## Test Plan This PR solely adds a new test to the existing code. ## Test Result Tests pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b0c13f3124
commit
b80e41f3bc
@@ -17,3 +17,11 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
# MMA layout validation test: only built when GPU_TARGETS names a gfx9, gfx11 or gfx12 target.
if(NOT GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
    message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target")
else()
    add_gtest_executable(test_amdgcn_mma_layout test_amdgcn_mma_layout.cpp)
    target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
|
||||
|
||||
|
||||
286
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp
Normal file
286
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp
Normal file
@@ -0,0 +1,286 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
#include "test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace ck = ck_tile;
|
||||
namespace mma = ck_tile::core::arch::mma;
|
||||
|
||||
// MMA register layout validation test for amdgcn_mma structs.
|
||||
//
|
||||
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors
|
||||
// A and B that contain exactly one non-zero element each, placed so that their product
|
||||
// contributes to a single output element C(m, n):
|
||||
//
|
||||
// A (M x K) B (K x N) C = A * B (M x N)
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . 1 . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
|
||||
//
|
||||
// The kernel uses RegisterMap to scatter A and B into the correct (lane, vecIdx) positions
|
||||
// of the MMA fragment registers, executes the intrinsic, then uses RegisterMap again to
|
||||
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
|
||||
// location.
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @class MmaLayoutTestKernel
|
||||
* @brief Device kernel that performs C = AB using a given Mma op
|
||||
*
|
||||
* @tparam ADataType Data type of tensor A elements
|
||||
* @tparam BDataType Data type of tensor B elements
|
||||
* @tparam CDataType Data type of tensor C elements
|
||||
* @tparam BlockM M-dimension of the MMA tile
|
||||
* @tparam BlockN N-dimension of the MMA tile
|
||||
* @tparam BlockK K-dimension of the MMA tile
|
||||
* @tparam BlockSize HIP block size
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
uint32_t BlockSize>
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
|
||||
__device__ void operator()(uint32_t* error_flags) const
|
||||
{
|
||||
using Selector =
|
||||
mma::MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockK,
|
||||
decltype(ck_tile::core::arch::get_compiler_target())>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
|
||||
if constexpr(MmaTraits::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaTraits::AVecType;
|
||||
using BVecType = typename MmaTraits::BVecType;
|
||||
using CVecType = typename MmaTraits::CVecType;
|
||||
constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
|
||||
const uint32_t lane = threadIdx.x;
|
||||
|
||||
AVecType a_frag{};
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
|
||||
const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN);
|
||||
const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK;
|
||||
const uint32_t n = case_idx % MmaTraits::BlockN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(uint32_t v = 0; v < a_vec_size; ++v)
|
||||
{
|
||||
auto a_coords = RegisterMap<MmaOp>::Register2AMap(lane, v);
|
||||
if(static_cast<uint32_t>(a_coords[0]) == m &&
|
||||
static_cast<uint32_t>(a_coords[1]) == k)
|
||||
{
|
||||
a_frag[v] = static_cast<ADataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
for(uint32_t v = 0; v < b_vec_size; ++v)
|
||||
{
|
||||
auto b_coords = RegisterMap<MmaOp>::Register2BMap(lane, v);
|
||||
if(static_cast<uint32_t>(b_coords[0]) == n &&
|
||||
static_cast<uint32_t>(b_coords[1]) == k)
|
||||
{
|
||||
b_frag[v] = static_cast<BDataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
|
||||
uint32_t err = 0;
|
||||
const CDataType tol = static_cast<CDataType>(
|
||||
1.0e-1f); // TODO: this tolerance might not be suitable for all data types and
|
||||
// should be revisited if we add more configurations
|
||||
for(uint32_t v = 0; v < c_vec_size; ++v)
|
||||
{
|
||||
auto c_coords = RegisterMap<MmaOp>::Register2CMap(lane, v);
|
||||
const uint32_t i = static_cast<uint32_t>(c_coords[0]);
|
||||
const uint32_t j = static_cast<uint32_t>(c_coords[1]);
|
||||
|
||||
const CDataType expected =
|
||||
(i == m && j == n) ? static_cast<CDataType>(1) : static_cast<CDataType>(0);
|
||||
const CDataType value = static_cast<CDataType>(c_frag[v]);
|
||||
if(fabsf(static_cast<float>(value - expected)) > static_cast<float>(tol))
|
||||
{
|
||||
err = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t any_err = __any(err);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
error_flags[case_idx] = any_err;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Test driver: runs the test for a given MMA configuration.
|
||||
*
|
||||
* The testlaunches (mkn) test cases (one per block) to check all possible positions of the "1" in
|
||||
* the A/B tensors.
|
||||
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
|
||||
* 2. Executes MMA intrinsic to compute C tensor.
|
||||
* 3. Checks if C has the 1 in the expected position.
|
||||
*
|
||||
* @tparam Selector Selector for the Mma operation
|
||||
* @return true if the test ran on hardware; false if skipped (no device or unsupported)
|
||||
*/
|
||||
template <typename Selector>
|
||||
bool run_mma_layout_test()
|
||||
{
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
using ADataType = typename MmaTraits::ADataType;
|
||||
using BDataType = typename MmaTraits::BDataType;
|
||||
using CDataType = typename MmaTraits::CDataType;
|
||||
constexpr uint32_t BlockM = MmaTraits::BlockM;
|
||||
constexpr uint32_t BlockN = MmaTraits::BlockN;
|
||||
constexpr uint32_t BlockK = MmaTraits::BlockK;
|
||||
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
|
||||
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
|
||||
|
||||
int device_count = 0;
|
||||
hipDevice_t device{};
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
|
||||
const auto runtime_target =
|
||||
ck_tile::core::arch::hip_device_prop_gcn_arch_name_to_amdgcn_target_id(props.gcnArchName);
|
||||
const bool has_device = device_count > 0;
|
||||
|
||||
if(!has_device || runtime_target == ck_tile::core::arch::amdgcn_target_id::HOST ||
|
||||
runtime_target != selector_target_id ||
|
||||
props.warpSize != static_cast<int>(selector_wave_size))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr uint32_t total_cases = BlockM * BlockK * BlockN;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
|
||||
|
||||
std::ignore = hipGetLastError();
|
||||
|
||||
using Kernel = MmaLayoutTestKernel<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockK,
|
||||
static_cast<int>(selector_wave_size)>;
|
||||
|
||||
std::ignore =
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{},
|
||||
dim3(total_cases),
|
||||
dim3(static_cast<int>(selector_wave_size)),
|
||||
0,
|
||||
d_error_ptr));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (BlockK * BlockN);
|
||||
const uint32_t k = (case_idx / BlockN) % BlockK;
|
||||
const uint32_t n = case_idx % BlockN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ==================== Test configurations per target ====================
|
||||
// TODO: currently we have only 1 specific target per test. This should be revisited to enable all
|
||||
// the targets within the family (gfx12, gfx11, gfx9)
|
||||
using MmaGfx1201CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx12_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1201>());
|
||||
using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX90A>());
|
||||
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1100>());
|
||||
|
||||
using MmaGfx1201Selector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1201CompilerTarget>;
|
||||
using MmaGfx90aSelector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx90aCompilerTarget>;
|
||||
using MmaGfx1100Selector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1100CompilerTarget>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
MmaGfx1201Selector,
|
||||
MmaGfx90aSelector,
|
||||
MmaGfx1100Selector
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Selector>
|
||||
class TestMmaLayout : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMmaLayout, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestMmaLayout, Mma_16x16x16_F16_F16_F32)
|
||||
{
|
||||
bool executed = run_mma_layout_test<TypeParam>();
|
||||
|
||||
if(!executed)
|
||||
{
|
||||
GTEST_SKIP() << "No supported HIP device found. Skipping test.";
|
||||
}
|
||||
}
|
||||
292
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp
Normal file
292
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp
Normal file
@@ -0,0 +1,292 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
/**
 * @class RegisterMapTraits
 * @brief Primary template of the traits that supply the tile_distribution_encoding types
 *        (AWarpDstrEncoding / BWarpDstrEncoding / CWarpDstrEncoding) for an MmaOp
 *
 * Only specializations are usable: instantiating the primary template fails at compile time.
 *
 * @tparam MmaOp amdgcn_mma specialization
 */
template <typename MmaOp>
struct RegisterMapTraits
{
    // The condition depends on MmaOp, so it only fires when this template is instantiated.
    static_assert(sizeof(MmaOp) == 0, "RegisterMapTraits requires a specialization");
};
|
||||
|
||||
/**
|
||||
* @class RegisterMap
|
||||
* @brief Uses specialized RegisterMapTraits to get the encoding
|
||||
* @tparam MmaOp amdgcn_mma specialization
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
struct RegisterMap
|
||||
{
|
||||
using Traits = RegisterMapTraits<MmaOp>;
|
||||
|
||||
using AMap = core::arch::mma::TileDistrEncRegMap<typename Traits::AWarpDstrEncoding>;
|
||||
using BMap = core::arch::mma::TileDistrEncRegMap<typename Traits::BWarpDstrEncoding>;
|
||||
using CMap = core::arch::mma::TileDistrEncRegMap<typename Traits::CWarpDstrEncoding>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2AMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return AMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2BMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return BMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2CMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return CMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
};
|
||||
|
||||
// ====================== Specializations per target =====================
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABK1PerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 8;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<kAMLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<kBNLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCM0PerLane, kCMLane, kCM1PerLane>,
|
||||
sequence<kCNLane>>, // <1, 2, 8>, <16>
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
using kABYs2RHsMajor = sequence<2>;
|
||||
using kABYs2RHsMinor = sequence<1>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<0, 0>;
|
||||
using kCYs2RHsMajor = sequence<1>;
|
||||
using kCYs2RHsMinor = sequence<1>;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kAMLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kBNLane>, sequence<MmaOp::kABKLane, MmaOp::kABKPerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MmaOp::kCMLane, MmaOp::kCM1PerLane>, sequence<MmaOp::kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
static_cast<index_t>(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
|
||||
static constexpr index_t AVecSize = vector_traits<typename MmaTraits::AVecType>::vector_size;
|
||||
static constexpr index_t BVecSize = vector_traits<typename MmaTraits::BVecType>::vector_size;
|
||||
static constexpr index_t CVecSize = vector_traits<typename MmaTraits::CVecType>::vector_size;
|
||||
|
||||
using kABPs2RHssMajor = sequence<0, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 0>;
|
||||
using kABYs2RHsMajor = sequence<2>;
|
||||
using kABYs2RHsMinor = sequence<0>;
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1>;
|
||||
using kCYs2RHsMinor = sequence<0>;
|
||||
|
||||
// TODO: remove these and fix constants in amdgcn_mma
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 1;
|
||||
static constexpr index_t kABK1PerLane = 16;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 8;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<2>, // kRepeat
|
||||
tuple<sequence<kAMLane>, sequence<kABK1PerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using BWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<2>, // kRepeat
|
||||
tuple<sequence<kBNLane>, sequence<kABK1PerLane>>,
|
||||
tuple<kABPs2RHssMajor>,
|
||||
tuple<kABPs2RHssMinor>,
|
||||
kABYs2RHsMajor,
|
||||
kABYs2RHsMinor>;
|
||||
|
||||
using CWarpDstrEncoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<kCM0PerLane, kCMLane>, sequence<kCNLane>>,
|
||||
tuple<kCPs2RHssMajor>,
|
||||
tuple<kCPs2RHssMinor>,
|
||||
kCYs2RHsMajor,
|
||||
kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user