mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK Tile] Add Tile Distribution Encoding Calculator (#5515)
## Motivation We want to be able to calculate TileDistributionEncodings describing register mappings for any MmaOp. This is necessary for further integration with CK Tile. This MR adds a new struct TileDistrEncCalc, which takes an amdgcn_mma type (MmaOp) and provides ABC warp distribution encodings for mapping matrix fragment coordinates to register coordinates (lane, vector item) and vice versa. It is able to take CTranspose, Swizzle, and NumAccessA / NumAccessB template parameters for tweaking the tile distributions. Swizzle modification will be implemented later. The current implementation can deal with all intrinsic types and block-hiding. This MR also adds some additional static asserts and derived params within amdgcn_mma_base, to enforce consistency and help calculate Tile Distributions for block-hiding intrinsics. An Example was added that uses the Tile Distr Enc Calc to calc and print register layouts for Tile Distributions for some of our amdgcn_mma structs. It also makes sure that the CTranspose modifier works as intended. Some additional gfx9 intrinsics were added to test block-hiding layouts for the different types of C-block-hiding layouts. The sparse intrinsic wrappers were updated according to Chris's recent changes in another branch (https://github.com/ROCm/rocm-libraries/pull/5508), which moved the compression step outside of the intrinsic itself. This is necessary to make sure that the Calculator can deal with this new interpretation of the sparse intrinsics. I directly copied the new amdgcn structs from Chris's branch and changed nothing else to avoid more complex merges in the future. Note that this means I did not update a bunch of related sparse code since that would be a lot, and therefore I disabled test_amdgcn_sparse_mma for now. 
The amdgcn_mma_layout test was refactored a bit: - The old register mapping utility was removed and its use was replaced by the new TileDistrEncCalc - More tests were added to test layouts for different types of block-hiding and sparse intrinsics - The Selector method was removed and the tests were split up over target architectures, with each target arch having a direct list of amdgcn structs to be tested. This ensures that we force specific tests on specific architectures and makes sure that the selector doesn't quietly do some workarounds like creating compound intrinsics. ## Test Results Layout tests based on calculated tile distribution encodings pass on all architectures. Calculator works for all currently added amdgcn structs, which includes different types of block-hiding and sparse intrinsics. Printed layouts from new example verified by eye. CTranspose modifier tested for large set of intrinsics.
This commit is contained in:
committed by
GitHub
parent
160bc1363e
commit
6cd016dde4
@@ -7,10 +7,11 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
# TODO: This test is temporarily disabled for cooperation / work planning reasons. Re-enable after merging related work.
|
||||
# if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
# add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
# target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
|
||||
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
@@ -18,10 +19,28 @@ else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_amdgcn_mma_layout test_amdgcn_mma_layout.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target")
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx942 test_amdgcn_mma_layout_gfx942.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx942 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
#include "test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace ck = ck_tile;
|
||||
namespace mma = ck_tile::core::arch::mma;
|
||||
|
||||
// MMA register layout validation test for amdgcn_mma structs.
|
||||
//
|
||||
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors
|
||||
// A and B that contain exactly one non-zero element each, placed so that their product
|
||||
// contributes to a single output element C(m, n):
|
||||
//
|
||||
// A (M x K) B (K x N) C = A * B (M x N)
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . 1 . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
|
||||
//
|
||||
// The kernel uses RegisterMap to scatter A and B into the correct (lane, vecIdx) positions
|
||||
// of the MMA fragment registers, executes the intrinsic, then uses RegisterMap again to
|
||||
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
|
||||
// location.
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @class MmaLayoutTestKernel
|
||||
* @brief Device kernel that performs C = AB using a given Mma op
|
||||
*
|
||||
* @tparam ADataType Data type of tensor A elements
|
||||
* @tparam BDataType Data type of tensor B elements
|
||||
* @tparam CDataType Data type of tensor C elements
|
||||
* @tparam FragM M-dimension of the MMA tile
|
||||
* @tparam FragN N-dimension of the MMA tile
|
||||
* @tparam FragK K-dimension of the MMA tile
|
||||
* @tparam BlockSize HIP block size
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
uint32_t BlockSize>
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
|
||||
__device__ void operator()(uint32_t* error_flags) const
|
||||
{
|
||||
using Selector =
|
||||
mma::MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
decltype(ck_tile::core::arch::get_compiler_target()),
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
|
||||
if constexpr(mma::MmaOpTraits<MmaOp>::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
constexpr uint32_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr uint32_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr uint32_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
|
||||
const uint32_t lane = threadIdx.x;
|
||||
|
||||
AVecType a_frag{};
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const uint32_t case_idx = static_cast<uint32_t>(blockIdx.x);
|
||||
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const uint32_t n = case_idx % MmaOp::kN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(uint32_t v = 0; v < a_vec_size; ++v)
|
||||
{
|
||||
auto a_coords = RegisterMap<MmaOp>::Register2AMap(lane, v);
|
||||
if(static_cast<uint32_t>(a_coords[0]) == m &&
|
||||
static_cast<uint32_t>(a_coords[1]) == k)
|
||||
{
|
||||
a_frag[v] = static_cast<ADataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
for(uint32_t v = 0; v < b_vec_size; ++v)
|
||||
{
|
||||
auto b_coords = RegisterMap<MmaOp>::Register2BMap(lane, v);
|
||||
if(static_cast<uint32_t>(b_coords[0]) == n &&
|
||||
static_cast<uint32_t>(b_coords[1]) == k)
|
||||
{
|
||||
b_frag[v] = static_cast<BDataType>(1);
|
||||
}
|
||||
}
|
||||
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
|
||||
uint32_t err = 0;
|
||||
const CDataType tol = static_cast<CDataType>(
|
||||
1.0e-1f); // TODO: this tolerance might not be suitable for all data types and
|
||||
// should be revisited if we add more configurations
|
||||
for(uint32_t v = 0; v < c_vec_size; ++v)
|
||||
{
|
||||
auto c_coords = RegisterMap<MmaOp>::Register2CMap(lane, v);
|
||||
const uint32_t i = static_cast<uint32_t>(c_coords[0]);
|
||||
const uint32_t j = static_cast<uint32_t>(c_coords[1]);
|
||||
|
||||
const CDataType expected =
|
||||
(i == m && j == n) ? static_cast<CDataType>(1) : static_cast<CDataType>(0);
|
||||
const CDataType value = static_cast<CDataType>(c_frag[v]);
|
||||
if(fabsf(static_cast<float>(value - expected)) > static_cast<float>(tol))
|
||||
{
|
||||
err = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t any_err = __any(err);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
error_flags[case_idx] = any_err;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Test driver: runs the test for a given MMA configuration.
|
||||
*
|
||||
 * The test launches (m*k*n) test cases (one per block) to check all possible positions of the "1" in
|
||||
* the A/B tensors.
|
||||
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
|
||||
* 2. Executes MMA intrinsic to compute C tensor.
|
||||
* 3. Checks if C has the 1 in the expected position.
|
||||
*
|
||||
* @tparam Selector Selector for the Mma operation
|
||||
* @return true if the test ran on hardware; false if skipped (no device or unsupported)
|
||||
*/
|
||||
template <typename Selector>
|
||||
bool run_mma_layout_test()
|
||||
{
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
using CDataType = typename MmaOp::CDataType;
|
||||
constexpr uint32_t FragM = MmaOp::kM;
|
||||
constexpr uint32_t FragN = MmaOp::kN;
|
||||
constexpr uint32_t FragK = MmaOp::kK;
|
||||
constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
|
||||
constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
|
||||
|
||||
int device_count = 0;
|
||||
hipDevice_t device{};
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
|
||||
const auto runtime_target =
|
||||
ck_tile::core::arch::hip_device_prop_gcn_arch_name_to_amdgcn_target_id(props.gcnArchName);
|
||||
const bool has_device = device_count > 0;
|
||||
|
||||
if(!has_device || runtime_target == ck_tile::core::arch::amdgcn_target_id::HOST ||
|
||||
runtime_target != selector_target_id ||
|
||||
props.warpSize != static_cast<int>(selector_wave_size))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr uint32_t total_cases = FragM * FragK * FragN;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
|
||||
|
||||
std::ignore = hipGetLastError();
|
||||
|
||||
using Kernel = MmaLayoutTestKernel<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
static_cast<int>(selector_wave_size)>;
|
||||
|
||||
std::ignore =
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{},
|
||||
dim3(total_cases),
|
||||
dim3(static_cast<int>(selector_wave_size)),
|
||||
0,
|
||||
d_error_ptr));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (FragK * FragN);
|
||||
const uint32_t k = (case_idx / FragN) % FragK;
|
||||
const uint32_t n = case_idx % FragN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ==================== Test configurations per target ====================
|
||||
// TODO: currently we have only 1 specific target per test. This should be revisited to enable all
|
||||
// the targets within the family (gfx12, gfx11, gfx9)
|
||||
using MmaGfx1201CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx12_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1201>());
|
||||
using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX90A>());
|
||||
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1100>());
|
||||
|
||||
using MmaGfx1201Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1201CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx90aSelector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx90aCompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx1100Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1100CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
MmaGfx1201Selector,
|
||||
MmaGfx90aSelector,
|
||||
MmaGfx1100Selector
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Selector>
|
||||
class TestMmaLayout : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMmaLayout, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestMmaLayout, Mma_16x16x16_F16_F16_F32)
|
||||
{
|
||||
bool executed = run_mma_layout_test<TypeParam>();
|
||||
|
||||
if(!executed)
|
||||
{
|
||||
GTEST_SKIP() << "No supported HIP device found. Skipping test.";
|
||||
}
|
||||
}
|
||||
239
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc
Normal file
239
test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc
Normal file
@@ -0,0 +1,239 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace mma;
|
||||
|
||||
using F16 = fp16_t;
|
||||
using F32 = fp32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
using Target942 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>());
|
||||
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
using Target11 = decltype(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>());
|
||||
using Target12 = decltype(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>());
|
||||
|
||||
// MMA register layout validation test for amdgcn_mma structs.
|
||||
//
|
||||
// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors A
|
||||
// and B that contain exactly one non-zero element each, placed so that their product contributes to
|
||||
// a single output element C(m, n):
|
||||
//
|
||||
// A (M x K) B (K x N) C = A * B (M x N)
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . 1 . . . . . . . . . . . . . . . . . . . .
|
||||
// . . . . . . . . . . . 1 . . . . . . . . . 1 . .
|
||||
// . . . . . . . . . . . . . . . . . . . . . . . .
|
||||
// A(m,k) = 1 B(k,n) = 1 C(m,n) = 1
|
||||
//
|
||||
// The kernel uses TileDistrEncRegMap to scatter A and B into the correct (lane, vecIdx) positions
|
||||
// of the MMA fragment registers, executes the intrinsic, then uses TileDistrEncRegMap again to
|
||||
// gather back into C matrix. The position of "1" in C is checked against the expected (m, n)
|
||||
// location.
|
||||
|
||||
/**
|
||||
* @class MmaLayoutTestKernel
|
||||
* @brief Device kernel that performs C = AB using a given Mma op
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma) to be tested
|
||||
*/
|
||||
template <typename MmaOp> // TODO: C++20 concept for MmaOp
|
||||
struct MmaLayoutTestKernel
|
||||
{
|
||||
static constexpr int kBlockSize = MmaOp::WaveSize;
|
||||
|
||||
__device__ void operator()(uint32_t* error_flags) const
|
||||
{
|
||||
using ARegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding>;
|
||||
using BRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding>;
|
||||
using CRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding>;
|
||||
|
||||
if constexpr(MmaOpTraits<MmaOp>::IsSupported)
|
||||
{
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
constexpr index_t a_vec_size = vector_traits<AVecType>::vector_size;
|
||||
constexpr index_t b_vec_size = vector_traits<BVecType>::vector_size;
|
||||
constexpr index_t c_vec_size = vector_traits<CVecType>::vector_size;
|
||||
|
||||
const index_t lane = threadIdx.x;
|
||||
|
||||
AVecType a_frag{};
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
uint32_t sparse_idx{};
|
||||
static_assert(MmaOp::kCompressionRatio <= 2); // Allow only 4:2 compression (or no).
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
const index_t case_idx = blockIdx.x;
|
||||
const index_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const index_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const index_t n = case_idx % MmaOp::kN;
|
||||
|
||||
// place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
|
||||
for(index_t v = 0; v < a_vec_size; ++v)
|
||||
{
|
||||
auto a_coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
|
||||
// When dealing with sparse intrinsics, the A matrix is compressed in the K
|
||||
// direction and we just put our "1" in the k / 2 position (rounded down).
|
||||
if(a_coords[0] == m && a_coords[1] == (k / MmaOp::kCompressionRatio))
|
||||
{
|
||||
a_frag[v] = 1;
|
||||
|
||||
// Calc an appropriate sparse idx value for a single 1 in position k. We use a
|
||||
// baseline index of 0x88888888. This sends each compressed index i to
|
||||
// uncompressed index i * 2. If k is odd, we should send it to i * 2 + 1
|
||||
// instead. We update only the absolutely necessary pair of bits for this
|
||||
// (idx[v*2:v*2+1]). Note that this simple calculation works for any 4:2 sparse
|
||||
// intrinsic with up to 16 packed k elements per lane.
|
||||
sparse_idx = 0x88888888 | ((k % 2) << (v * 2));
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t v = 0; v < b_vec_size; ++v)
|
||||
{
|
||||
auto b_coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
if(b_coords[0] == n && b_coords[1] == k)
|
||||
{
|
||||
b_frag[v] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(MmaOpTraits<MmaOp>::IsSparse)
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, sparse_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
}
|
||||
|
||||
// TODO: this tolerance might not be suitable for all data types and
|
||||
// should be revisited if we add more configurations
|
||||
const float tolerance = 1.0e-1f;
|
||||
index_t err = 0;
|
||||
|
||||
for(index_t v = 0; v < c_vec_size; ++v)
|
||||
{
|
||||
auto c_coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
|
||||
const float expected = (c_coords[0] == m && c_coords[1] == n) ? 1 : 0;
|
||||
const float value = static_cast<float>(c_frag[v]);
|
||||
if(std::fabs(value - expected) > tolerance)
|
||||
{
|
||||
err = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t any_err = __any(err);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
error_flags[case_idx] = any_err;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Test driver: runs the test for a given MMA configuration.
|
||||
*
|
||||
 * The test launches (m*k*n) test cases (one per block) to check all possible positions of the "1" in
|
||||
* the A/B tensors.
|
||||
* 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
|
||||
* 2. Executes MMA intrinsic to compute C tensor.
|
||||
* 3. Checks if C has the 1 in the expected position.
|
||||
*
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma) to be tested
|
||||
*/
|
||||
template <typename MmaOp> // TODO: C++20 concept for MmaOp
|
||||
void run_mma_layout_test()
|
||||
{
|
||||
EXPECT_TRUE(MmaOpTraits<MmaOp>::IsSupported) << "Unsupported MmaOp! Bad MmaOp in list!\n";
|
||||
|
||||
int device_count = 0;
|
||||
hipDevice_t device{};
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
|
||||
EXPECT_TRUE(device_count > 0) << "No device found!";
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
EXPECT_EQ(props.warpSize, static_cast<int>(MmaOp::WaveSize))
|
||||
<< "Device wavesize " << props.warpSize << " != Mma wavesize " << MmaOp::WaveSize;
|
||||
|
||||
constexpr uint32_t total_cases = MmaOp::kM * MmaOp::kN * MmaOp::kK;
|
||||
ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
|
||||
std::vector<uint32_t> h_errors(total_cases, 0u);
|
||||
|
||||
auto* d_error_ptr = static_cast<uint32_t*>(d_errors.GetDeviceBuffer());
|
||||
|
||||
(void)hipGetLastError();
|
||||
|
||||
using Kernel = MmaLayoutTestKernel<MmaOp>;
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{}, dim3(total_cases), dim3(MmaOp::WaveSize), 0, d_error_ptr));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
|
||||
|
||||
for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
|
||||
{
|
||||
const uint32_t m = case_idx / (MmaOp::kK * MmaOp::kN);
|
||||
const uint32_t k = (case_idx / MmaOp::kN) % MmaOp::kK;
|
||||
const uint32_t n = case_idx % MmaOp::kN;
|
||||
|
||||
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
|
||||
}
|
||||
}
|
||||
|
||||
// Lists of intrinsics to test.
|
||||
// clang-format off
|
||||
using Gfx9Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
|
||||
>;
|
||||
using Gfx942Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
|
||||
>;
|
||||
using Gfx950Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE> // mfma_f32_16x16x32_f16
|
||||
>;
|
||||
using Gfx11Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
|
||||
>;
|
||||
using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename MmaOp>
|
||||
class TestMmaLayout : public ::testing::Test
|
||||
{
|
||||
};
|
||||
} // namespace
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx11Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx11Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx12Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx12Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx9Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx9Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx942Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx942Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_amdgcn_mma_layout.inc"
|
||||
TYPED_TEST_SUITE(TestMmaLayout, Gfx950Intrinsics);
|
||||
TYPED_TEST(TestMmaLayout, Gfx950Intrinsics) { run_mma_layout_test<TypeParam>(); }
|
||||
@@ -1,306 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
/**
|
||||
* @class RegisterMapTraits
|
||||
* @brief Traits class that defines tile_distribution_encoding for each MmaOp
|
||||
* @tparam MmaOp amdgcn_mma specialization
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
struct RegisterMapTraits
|
||||
{
|
||||
static_assert(sizeof(MmaOp) == 0, "RegisterMapTraits requires a specialization");
|
||||
};
|
||||
|
||||
/**
|
||||
* @class RegisterMap
|
||||
* @brief Uses specialized RegisterMapTraits to get the encoding
|
||||
* @tparam MmaOp amdgcn_mma specialization
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
struct RegisterMap
|
||||
{
|
||||
using Traits = RegisterMapTraits<MmaOp>;
|
||||
|
||||
using AMap = core::arch::mma::TileDistrEncRegMap<typename Traits::AWarpDstrEncoding>;
|
||||
using BMap = core::arch::mma::TileDistrEncRegMap<typename Traits::BWarpDstrEncoding>;
|
||||
using CMap = core::arch::mma::TileDistrEncRegMap<typename Traits::CWarpDstrEncoding>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2AMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return AMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2BMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return BMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto Register2CMap(const uint32_t lane, const uint32_t vecIdx)
|
||||
{
|
||||
return CMap::calc_matrix_indices_from_lane_vector(static_cast<index_t>(lane),
|
||||
static_cast<index_t>(vecIdx));
|
||||
}
|
||||
};
|
||||
|
||||
// ====================== Specializations per target =====================
|
||||
|
||||
/**
 * @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
 *
 * Encodes the per-warp register layout of the gfx12 16x16x16 f16->f32 WMMA
 * op as tile_distribution_encodings for the A, B and C fragments. These are
 * consumed by RegisterMap / TileDistrEncRegMap to translate between
 * (lane, vector element) register coordinates and matrix coordinates.
 */
template <typename CtrlFlags,
          typename CompilerTarget,
          ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    // 16x16x16 tile shape, matching the intrinsic named in the brief above.
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    OpFamily_,
    // Specialization is enabled only when compiling for the gfx12 family.
    ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
{
    // Re-spelled MmaOp type matching the specialized template arguments.
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget,
                                                       OpFamily_>;

    // Number of elements each lane holds for the A/B/C fragments.
    static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;

    // P (lane) and Y (per-lane vector element) dimension -> (R, H) mappings
    // used by the encodings below.
    // NOTE(review): assumes tile_distribution_encoding's convention that a
    // major index of 0 selects R and 1..n selects the n-th H sequence, with
    // minor selecting the component within it — confirm against the
    // tile_distribution_encoding definition.
    using kABPs2RHssMajor = sequence<2, 1>;
    using kABPs2RHssMinor = sequence<1, 0>;
    using kABYs2RHsMajor = sequence<2, 2>;
    using kABYs2RHsMinor = sequence<0, 2>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1, 1>;
    using kCYs2RHsMinor = sequence<0, 2>;

    // TODO: remove these and fix constants in amdgcn_mma
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16; // lanes spanning M (A) / shared by A and B below
    static constexpr index_t kBNLane = 16; // lanes spanning N (B)
    static constexpr index_t kABK0PerLane = 1;
    static constexpr index_t kABKLane = 2;     // lanes spanning K
    static constexpr index_t kABK1PerLane = 8; // K elements per lane (1 * 2 * 8 = K = 16)
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 8; // C elements per lane
    // Lane counts: 16 (M/N) * 2 (K) = 32 — presumably a wave32 warp; confirm.

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>, // R: no replication
        tuple<sequence<kAMLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;

    // B uses the same lane/vector layout as A, with N in place of M.
    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<1>,
        tuple<sequence<kBNLane>, sequence<kABK0PerLane, kABKLane, kABK1PerLane>>, // <16>, <1, 2, 8>
        tuple<kABPs2RHssMajor>,
        tuple<kABPs2RHssMinor>,
        kABYs2RHsMajor,
        kABYs2RHsMinor>;

    using CWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kCM0PerLane, kCMLane, kCM1PerLane>,
                                         sequence<kCNLane>>, // <1, 2, 8>, <16>
                                   tuple<kCPs2RHssMajor>,
                                   tuple<kCPs2RHssMinor>,
                                   kCYs2RHsMajor,
                                   kCYs2RHsMinor>;
};
|
||||
|
||||
/**
 * @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
 *
 * Encodes the per-warp register layout of the gfx9 16x16x16 f16->f32 MFMA
 * op as tile_distribution_encodings for the A, B and C fragments, consumed
 * by RegisterMap / TileDistrEncRegMap.
 */
template <typename CtrlFlags,
          typename CompilerTarget,
          ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    // 16x16x16 tile shape, matching the intrinsic named in the brief above.
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    OpFamily_,
    // Specialization is enabled only when compiling for the gfx9 family.
    ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
{
    // Re-spelled MmaOp type matching the specialized template arguments.
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget,
                                                       OpFamily_>;

    // Number of elements each lane holds for the A/B/C fragments.
    static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;

    // P (lane) and Y (per-lane vector element) dimension -> (R, H) mappings
    // used by the encodings below.
    // NOTE(review): assumes tile_distribution_encoding's convention that a
    // major index of 0 selects R and 1..n selects the n-th H sequence, with
    // minor selecting the component within it — confirm against the
    // tile_distribution_encoding definition.
    using kABPs2RHssMajor = sequence<2, 1>;
    using kABPs2RHssMinor = sequence<0, 0>;
    using kABYs2RHsMajor = sequence<2>;
    using kABYs2RHsMinor = sequence<1>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<0, 0>;
    using kCYs2RHsMajor = sequence<1>;
    using kCYs2RHsMinor = sequence<1>;

    // TODO: remove these and fix constants in amdgcn_mma
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16; // lanes spanning M (A)
    static constexpr index_t kBNLane = 16; // lanes spanning N (B)
    static constexpr index_t kABKLane = 4;    // lanes spanning K
    static constexpr index_t kABKPerLane = 4; // K elements per lane (4 * 4 = K = 16)
    static constexpr index_t kCMLane = 4;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 1; // NOTE(review): unused by CWarpDstrEncoding below — confirm intended
    static constexpr index_t kCM1PerLane = 4; // C elements per lane
    // Lane counts: 16 (M/N) * 4 (K) = 64 — presumably a wave64 warp; confirm.

    using AWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>, // R: no replication
                                   tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;

    // B uses the same lane/vector layout as A, with N in place of M.
    using BWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;

    using CWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kCMLane, kCM1PerLane>, sequence<kCNLane>>,
                                   tuple<kCPs2RHssMajor>,
                                   tuple<kCPs2RHssMinor>,
                                   kCYs2RHsMajor,
                                   kCYs2RHsMinor>;
};
|
||||
|
||||
/**
 * @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
 *
 * Encodes the per-warp register layout of the gfx11 16x16x16 f16->f32 WMMA
 * op as tile_distribution_encodings for the A, B and C fragments, consumed
 * by RegisterMap / TileDistrEncRegMap. Unlike the gfx12 variant, A and B
 * carry a replication factor of 2 (see kRepeat below).
 */
template <typename CtrlFlags,
          typename CompilerTarget,
          ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
    ck_tile::fp16_t,
    ck_tile::fp16_t,
    ck_tile::fp32_t,
    // 16x16x16 tile shape, matching the intrinsic named in the brief above.
    16u,
    16u,
    16u,
    CtrlFlags,
    CompilerTarget,
    OpFamily_,
    // Specialization is enabled only when compiling for the gfx11 family.
    ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
{
    // Re-spelled MmaOp type matching the specialized template arguments.
    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
                                                       ck_tile::fp16_t,
                                                       ck_tile::fp32_t,
                                                       16u,
                                                       16u,
                                                       16u,
                                                       CtrlFlags,
                                                       CompilerTarget,
                                                       OpFamily_>;

    // Number of elements each lane holds for the A/B/C fragments.
    static constexpr index_t AVecSize = vector_traits<typename MmaOp::AVecType>::vector_size;
    static constexpr index_t BVecSize = vector_traits<typename MmaOp::BVecType>::vector_size;
    static constexpr index_t CVecSize = vector_traits<typename MmaOp::CVecType>::vector_size;

    // P (lane) and Y (per-lane vector element) dimension -> (R, H) mappings
    // used by the encodings below. For A/B, the first P dimension maps to R
    // (major 0), i.e. it indexes the replicated copies.
    // NOTE(review): assumes tile_distribution_encoding's convention that a
    // major index of 0 selects R and 1..n selects the n-th H sequence, with
    // minor selecting the component within it — confirm against the
    // tile_distribution_encoding definition.
    using kABPs2RHssMajor = sequence<0, 1>;
    using kABPs2RHssMinor = sequence<0, 0>;
    using kABYs2RHsMajor = sequence<2>;
    using kABYs2RHsMinor = sequence<0>;
    using kCPs2RHssMajor = sequence<1, 2>;
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1>;
    using kCYs2RHsMinor = sequence<0>;

    // TODO: remove these and fix constants in amdgcn_mma
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16; // lanes spanning M (A)
    static constexpr index_t kBNLane = 16; // lanes spanning N (B)
    static constexpr index_t kABK0PerLane = 1; // NOTE(review): unused by the encodings below — confirm intended
    static constexpr index_t kABKLane = 1;     // NOTE(review): unused by the encodings below — confirm intended
    static constexpr index_t kABK1PerLane = 16; // K elements per lane (= K = 16; no K split across lanes)
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 16;
    static constexpr index_t kCM0PerLane = 8; // C elements per lane
    static constexpr index_t kCM1PerLane = 1; // NOTE(review): unused by CWarpDstrEncoding below — confirm intended
    // Lane counts: 16 (M/N) * 2 (replication) = 32 — presumably a wave32
    // warp whose two halves hold duplicate A/B data; confirm.

    using AWarpDstrEncoding =
        tile_distribution_encoding<sequence<2>, // kRepeat
                                   tuple<sequence<kAMLane>, sequence<kABK1PerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;

    // B uses the same lane/vector layout as A, with N in place of M.
    using BWarpDstrEncoding =
        tile_distribution_encoding<sequence<2>, // kRepeat
                                   tuple<sequence<kBNLane>, sequence<kABK1PerLane>>,
                                   tuple<kABPs2RHssMajor>,
                                   tuple<kABPs2RHssMinor>,
                                   kABYs2RHsMajor,
                                   kABYs2RHsMinor>;

    using CWarpDstrEncoding =
        tile_distribution_encoding<sequence<1>, // R: no replication for C
                                   tuple<sequence<kCM0PerLane, kCMLane>, sequence<kCNLane>>,
                                   tuple<kCPs2RHssMajor>,
                                   tuple<kCPs2RHssMinor>,
                                   kCYs2RHsMajor,
                                   kCYs2RHsMinor>;
};
|
||||
|
||||
// ========================================================================
|
||||
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user