mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Fix warnings during wrapper docs generation (#1192)
* Fix warnings during wrapper docs generation * Fixes
This commit is contained in:
@@ -45,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS:
|
||||
|
||||
extensions += ['sphinxcontrib.bibtex']
|
||||
bibtex_bibfiles = ['refs.bib']
|
||||
|
||||
cpp_id_attributes = ["__global__", "__device__", "__host__"]
|
||||
|
||||
@@ -63,30 +63,31 @@ Advanced examples:
|
||||
Layout
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenstruct:: ck::wrapper::Layout
|
||||
.. doxygenstruct:: Layout
|
||||
|
||||
-------------------------------------
|
||||
Layout helpers
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: layout_utils.hpp
|
||||
.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp
|
||||
|
||||
-------------------------------------
|
||||
Tensor
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenstruct:: ck::wrapper::Tensor
|
||||
.. doxygenstruct:: Tensor
|
||||
|
||||
-------------------------------------
|
||||
Tensor helpers
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: tensor_utils.hpp
|
||||
.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp
|
||||
|
||||
.. doxygenfile:: tensor_partition.hpp
|
||||
.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp
|
||||
|
||||
-------------------------------------
|
||||
Operations
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: copy.hpp
|
||||
.. doxygenfile:: include/ck/wrapper/operations/copy.hpp
|
||||
.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
|
||||
#include "ck/wrapper/utils/layout_utils.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Layout wrapper that performs the tensor descriptor logic.
|
||||
@@ -19,6 +22,8 @@ namespace wrapper {
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
struct Layout
|
||||
{
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
private:
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -246,6 +251,7 @@ struct Layout
|
||||
using Descriptor1dType =
|
||||
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
|
||||
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
|
||||
/// @endcond
|
||||
|
||||
public:
|
||||
/**
|
||||
@@ -454,6 +460,8 @@ struct Layout
|
||||
return unrolled_descriptor_;
|
||||
}
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
private:
|
||||
// All dimensions are unrolled
|
||||
UnrolledDescriptorType unrolled_descriptor_;
|
||||
@@ -466,6 +474,7 @@ struct Layout
|
||||
// Descriptor1dType lengths: (8)
|
||||
// MergedNestsDescriptorType lengths: (4, 2)
|
||||
const Shape shape_;
|
||||
/// @endcond
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
@@ -10,8 +10,11 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Perform generic copy between two tensors partitions (threadwise copy).
|
||||
|
||||
395
include/ck/wrapper/operations/gemm.hpp
Normal file
395
include/ck/wrapper/operations/gemm.hpp
Normal file
@@ -0,0 +1,395 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/wrapper/utils/tensor_utils.hpp"
|
||||
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace {
|
||||
namespace detail {
|
||||
/**
|
||||
* \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
|
||||
*
|
||||
*
|
||||
* \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
|
||||
* \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
|
||||
*
|
||||
* \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
|
||||
*/
|
||||
template <index_t K1, typename TileLayout>
|
||||
__device__ constexpr auto GetBlockDescriptor()
|
||||
{
|
||||
using TileLayoutShape = typename TileLayout::LayoutShape;
|
||||
using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;
|
||||
|
||||
constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
|
||||
// MPerBlock or NPerBlock
|
||||
constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};
|
||||
|
||||
constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
|
||||
TileLayoutDescriptor{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
|
||||
make_pass_through_transform(Dim0)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_block_desc_k0_m_k1;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
|
||||
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
|
||||
* (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
|
||||
* or (K0PerBlock, NPerBlock, K1).
|
||||
*
|
||||
* \note C output Vgpr register layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension per tile.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension per tile.
|
||||
* - MWave - Equals to 1 since this is for single wave.
|
||||
* - NWave - Equals to 1 since this is for single wave.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam BlockSize Number of threads in block.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
|
||||
* (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
|
||||
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
|
||||
* (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
|
||||
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits,
|
||||
typename ATensorType,
|
||||
typename BTensorType,
|
||||
typename CTensorType>
|
||||
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
|
||||
const BTensorType& b_local_tile_tensor,
|
||||
CTensorType& c_reg_tensor)
|
||||
{
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
|
||||
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
|
||||
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
|
||||
static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
|
||||
static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
|
||||
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
|
||||
|
||||
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
|
||||
typename BTileLayout::LayoutShape{}.Size());
|
||||
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
|
||||
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename ATileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename BTileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
|
||||
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>
|
||||
blockwise_gemm_xdl_op{};
|
||||
|
||||
blockwise_gemm_xdl_op.Run(
|
||||
a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local partition per thread for C tensor.
|
||||
*
|
||||
* \note C output global memory layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension.
|
||||
* - MWave - The number of waves in single tile M dimension per tile.
|
||||
* - NWave - The number of waves in single tile N dimension per tile.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam ATileLayout A tensor layout.
|
||||
* \tparam BTileLayout B tensor layout.
|
||||
* \tparam BlockSize Number of threads in block.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
* \param c_local_tile_tensor C tensor in LDS memory for blockwise gemm
|
||||
* (MPerBlock, NPerBlock) layout.
|
||||
*
|
||||
* \return Partition c tensor for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
typename ATileLayout,
|
||||
typename BTileLayout,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits,
|
||||
typename CTensorType>
|
||||
__host__ __device__ constexpr auto
|
||||
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
|
||||
typename BTileLayout::LayoutShape{}.Size());
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename ATileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename BTileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
|
||||
|
||||
using BlockwiseGemmXdlops =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>;
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
|
||||
|
||||
// Calculate offset on grid
|
||||
const auto c_thread_mtx_on_block =
|
||||
BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_grid_idx =
|
||||
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_grid));
|
||||
|
||||
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_grid_idx =
|
||||
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_grid));
|
||||
// Create partition shape based on descriptor dims.
|
||||
const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
|
||||
|
||||
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
layout(c_local_tile_tensor).GetUnrolledDescriptor());
|
||||
|
||||
const auto lower_upper_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
|
||||
|
||||
auto sliced_desc = transform_tensor_descriptor(
|
||||
partition_desc,
|
||||
make_tuple(
|
||||
make_slice_transform(partition_shape.At(Number<0>{}),
|
||||
m_thread_data_on_grid_idx[I0],
|
||||
partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
|
||||
make_slice_transform(partition_shape.At(Number<1>{}),
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
|
||||
make_slice_transform(partition_shape.At(Number<2>{}),
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
|
||||
make_slice_transform(partition_shape.At(Number<3>{}),
|
||||
n_thread_data_on_grid_idx[I1],
|
||||
partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
|
||||
make_slice_transform(partition_shape.At(Number<4>{}),
|
||||
m_thread_data_on_grid_idx[I2],
|
||||
partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
|
||||
make_slice_transform(partition_shape.At(Number<5>{}),
|
||||
m_thread_data_on_grid_idx[I3],
|
||||
partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
|
||||
make_slice_transform(partition_shape.At(Number<6>{}),
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
|
||||
make_slice_transform(partition_shape.At(Number<7>{}),
|
||||
n_thread_data_on_grid_idx[I2],
|
||||
partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
|
||||
lower_upper_dims,
|
||||
lower_upper_dims);
|
||||
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
|
||||
partition_shape, sliced_desc);
|
||||
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
|
||||
c_local_tile_tensor.GetPointer(), partition_layout);
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create C tensor in Vgpr registers for blockwise gemm.
|
||||
*
|
||||
* \note C output Vgpr register layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension per tile.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension per tile.
|
||||
* - MWave - Equals to 1 since this is for single wave.
|
||||
* - NWave - Equals to 1 since this is for single wave.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam ATileLayout A tensor layout.
|
||||
* \tparam BTileLayout B tensor layout.
|
||||
* \tparam BlockSize Number of threads in block.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
*
|
||||
* \return Vgpr c tensor for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
typename ATileLayout,
|
||||
typename BTileLayout,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits>
|
||||
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
|
||||
typename BTileLayout::LayoutShape{}.Size());
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename ATileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
conditional_t<is_3d_desc,
|
||||
typename BTileLayout::LayoutUnrolledDescriptorType,
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
|
||||
|
||||
using BlockwiseGemmXdlops =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>;
|
||||
// Calcualte descriptor, shape and layout
|
||||
constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
|
||||
vgpr_desc.GetLengths()[I1],
|
||||
vgpr_desc.GetLengths()[I2],
|
||||
vgpr_desc.GetLengths()[I3],
|
||||
vgpr_desc.GetLengths()[I4],
|
||||
vgpr_desc.GetLengths()[I5],
|
||||
vgpr_desc.GetLengths()[I6],
|
||||
vgpr_desc.GetLengths()[I7]);
|
||||
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
|
||||
vgpr_shape, vgpr_desc);
|
||||
// Get vector type for Vgpr
|
||||
constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
|
||||
using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
|
||||
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
|
||||
vgpr_layout);
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -7,9 +7,15 @@
|
||||
#include "utils/tensor_partition.hpp"
|
||||
#include "utils/layout_utils.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace {
|
||||
namespace detail {
|
||||
namespace {
|
||||
/**
|
||||
@@ -188,7 +194,11 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
|
||||
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Tensor wrapper that performs static and dynamic buffer logic.
|
||||
@@ -391,6 +401,8 @@ struct Tensor
|
||||
}
|
||||
|
||||
private:
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
ElementSpaceSize,
|
||||
@@ -417,6 +429,7 @@ struct Tensor
|
||||
// tensor descriptor (thus all it's transforms) and is linear (1D).
|
||||
// We store base_offset_ to avoid multiple recalculations.
|
||||
index_t base_offset_;
|
||||
/// @endcond
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
81
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
Normal file
81
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
Normal file
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Traits for blockwise gemm xdl.
|
||||
*
|
||||
* \tparam MPerXDLValue The MFMA instruction size in M dimension.
|
||||
* \tparam NPerXDLValue The MFMA instruction size in N dimension.
|
||||
* \tparam MXdlPerWaveValue The number of MFMA instructions run by single
|
||||
* wave in M dimension.
|
||||
* \tparam NXdlPerWaveValue The number of MFMA instructions run by single
|
||||
* wave in N dimension.
|
||||
* \tparam K1Value The number of K-dim elements that are packed together as
|
||||
* a separate logical dimension. Usually aligns with vector load size.
|
||||
*/
|
||||
template <typename MPerXDLValue,
          typename NPerXDLValue,
          typename MXdlPerWaveValue,
          typename NXdlPerWaveValue,
          typename K1Value>
struct BlockwisGemmXdlTraits
{
    // Each member is a value-initialized instance of the matching template
    // argument, so the configuration is carried in the members' types.
    static constexpr MPerXDLValue MPerXDL{};
    static constexpr NPerXDLValue NPerXDL{};
    static constexpr MXdlPerWaveValue MXdlPerWave{};
    static constexpr NXdlPerWaveValue NXdlPerWave{};
    static constexpr K1Value K1{};
};
|
||||
|
||||
// K1 = 4
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
|
||||
{
|
||||
};
|
||||
// K1 = 8
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
|
||||
{
|
||||
};
|
||||
// K1 = 16
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
17
include/ck/wrapper/utils/kernel_utils.hpp
Normal file
17
include/ck/wrapper/utils/kernel_utils.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
// Expands to a __launch_bounds__ attribute built from the CK_MAX_THREAD_PER_BLOCK
// and CK_MIN_BLOCK_PER_CU macros. Being a preprocessor macro, it is not scoped
// by the enclosing ck::wrapper namespace.
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -16,11 +16,14 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
/// @cond INTERNAL
|
||||
// forward declaration
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
struct Layout;
|
||||
|
||||
@@ -9,9 +9,14 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace {
|
||||
|
||||
/**
|
||||
@@ -70,6 +75,7 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
}
|
||||
|
||||
} // namespace
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread (At now only packed partition
|
||||
|
||||
@@ -12,8 +12,11 @@
|
||||
#include "ck/utility/amd_address_space.hpp"
|
||||
#include "ck/utility/multi_index.hpp"
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond INTERNAL
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Memory type, allowed members:
|
||||
@@ -26,7 +29,7 @@ namespace wrapper {
|
||||
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
/// @cond INTERNAL
|
||||
// forward declarations
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
struct Layout;
|
||||
|
||||
Reference in New Issue
Block a user