mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Add blockwise gemm to ck wrapper (#1139)
* Add blockwise gemm to ck wrapper
* Add blockwise gemm traits
* Disable test_gemm for non xdl devices
* Fixes
* Add c layout descritpions
[ROCm/composable_kernel commit: f3b6c23ac5]
This commit is contained in:
@@ -248,6 +248,9 @@ struct Layout
|
||||
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
|
||||
|
||||
public:
|
||||
using LayoutShape = Shape;
|
||||
using LayoutUnrolledDescriptorType = UnrolledDescriptorType;
|
||||
|
||||
/**
|
||||
* \brief Transform descriptor to align to passed indexes.
|
||||
*
|
||||
|
||||
@@ -3,45 +3,18 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../utils/tensor_utils.hpp"
|
||||
#include "ck/wrapper/utils/tensor_utils.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
/**
|
||||
* \brief Perform generic copy between two tensors partitions (threadwise copy).
|
||||
* Tensors must have the same size.
|
||||
*
|
||||
* \param src_tensor Source tensor.
|
||||
* \param dst_tensor Destination tensor.
|
||||
*/
|
||||
template <typename SrcTensorType, typename DstTensorType>
|
||||
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
|
||||
{
|
||||
if constexpr(!SrcTensorType::IsDynamicBuffer)
|
||||
{
|
||||
using SizeType = decltype(size(src_tensor));
|
||||
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
|
||||
}
|
||||
else if constexpr(!DstTensorType::IsDynamicBuffer)
|
||||
{
|
||||
using SizeType = decltype(size(dst_tensor));
|
||||
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < size(src_tensor); i++)
|
||||
{
|
||||
dst_tensor(i) = src_tensor(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform optimized copy between two tensors partitions (threadwise copy).
|
||||
* Tensors must have the same size.
|
||||
@@ -167,9 +140,99 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
|
||||
else
|
||||
{
|
||||
// Perform copy between StaticBuffers
|
||||
copy(src_tensor, dst_tensor);
|
||||
static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform generic copy between two tensors partitions (threadwise copy).
|
||||
* Tensors must have the same size.
|
||||
*
|
||||
* \param src_tensor Source tensor.
|
||||
* \param dst_tensor Destination tensor.
|
||||
*/
|
||||
template <typename SrcTensorType, typename DstTensorType>
|
||||
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
|
||||
{
|
||||
// Generate default params
|
||||
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
|
||||
constexpr index_t num_dims = SrcShapeType::Size();
|
||||
// Incrementing dims 0, 1, 2 ... num_dims - 1
|
||||
constexpr auto dim_access_order_tuple =
|
||||
generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
|
||||
constexpr index_t vector_dim = num_dims - 1;
|
||||
constexpr index_t scalar_per_vector = 1;
|
||||
copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform optimized blockwise copy between two tensors. Tensors must have the
|
||||
* same size.
|
||||
*
|
||||
* \note At now Vgpr and Sgpr are not supported.
|
||||
*
|
||||
* \tparam DimAccessOrderTuple Tuple with dimension access order.
|
||||
* \tparam VectorDim Dimension for vectorize read and write.
|
||||
* \tparam ScalarPerVector Number of scalar per vectorize read and write.
|
||||
* \param src_tensor Source tensor.
|
||||
* \param dst_tensor Destination tensor.
|
||||
* \param thread_layout Thread layout per each dimension for copy.
|
||||
*/
|
||||
template <typename DimAccessOrderTuple,
|
||||
index_t VectorDim,
|
||||
index_t ScalarPerVector,
|
||||
typename SrcTensorType,
|
||||
typename DstTensorType,
|
||||
typename ThreadLayoutTuple>
|
||||
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
|
||||
DstTensorType& dst_tensor,
|
||||
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
|
||||
{
|
||||
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
|
||||
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
|
||||
|
||||
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
|
||||
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
|
||||
|
||||
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
|
||||
constexpr index_t num_dims = SrcShapeType::Size();
|
||||
|
||||
constexpr auto tile_lengths_seq =
|
||||
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
|
||||
constexpr auto thread_layout_seq = generate_sequence_v2(
|
||||
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
|
||||
constexpr auto dim_access_order = generate_sequence_v2(
|
||||
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
|
||||
|
||||
// Perform copy between DynamicBuffers
|
||||
auto transfer = ThreadGroupTensorSliceTransfer_v7<
|
||||
ThisThreadBlock,
|
||||
Tuple<typename SrcTensorType::TensorElementType>,
|
||||
Tuple<typename DstTensorType::TensorElementType>,
|
||||
decltype(tie(in_grid_desc)),
|
||||
decltype(tie(out_grid_desc)),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
|
||||
std::remove_const_t<decltype(tile_lengths_seq)>,
|
||||
std::remove_const_t<decltype(thread_layout_seq)>,
|
||||
std::remove_const_t<decltype(dim_access_order)>,
|
||||
std::remove_const_t<decltype(dim_access_order)>,
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
Sequence<true>,
|
||||
Sequence<true>>{in_grid_desc,
|
||||
make_tuple(src_tensor.GetMultiIdxOffsets()),
|
||||
out_grid_desc,
|
||||
make_tuple(dst_tensor.GetMultiIdxOffsets()),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
transfer.Run(tie(in_grid_desc),
|
||||
tie(src_tensor.GetBuffer()),
|
||||
tie(out_grid_desc),
|
||||
tie(dst_tensor.GetBuffer()));
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
|
||||
337
include/ck/wrapper/operations/gemm.hpp
Normal file
337
include/ck/wrapper/operations/gemm.hpp
Normal file
@@ -0,0 +1,337 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/wrapper/utils/tensor_utils.hpp"
|
||||
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
namespace {
|
||||
namespace detail {
|
||||
/**
|
||||
* \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
|
||||
*
|
||||
*
|
||||
* \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
|
||||
* \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
|
||||
*
|
||||
* \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
|
||||
*/
|
||||
template <index_t K1, typename TileLayout>
|
||||
__device__ constexpr auto GetBlockDescriptor()
|
||||
{
|
||||
using TileLayoutShape = typename TileLayout::LayoutShape;
|
||||
using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;
|
||||
|
||||
constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
|
||||
// MPerBlock or NPerBlock
|
||||
constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};
|
||||
|
||||
constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
|
||||
TileLayoutDescriptor{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
|
||||
make_pass_through_transform(Dim0)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_block_desc_k0_m_k1;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
|
||||
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
|
||||
* data layout must be (NPerBlock, KPerBlock).
|
||||
*
|
||||
* \note C output Vgpr register layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension per tile.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension per tile.
|
||||
* - MWave - Equals to 1 since this is for single wave.
|
||||
* - NWave - Equals to 1 since this is for single wave.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam BlockSize Tensor to pad.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
|
||||
* (MPerBlock, KPerBlock) layout.
|
||||
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
|
||||
* (NPerBlock, KPerBlock) layout.
|
||||
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits,
|
||||
typename ATensorType,
|
||||
typename BTensorType,
|
||||
typename CTensorType>
|
||||
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
|
||||
const BTensorType& b_local_tile_tensor,
|
||||
CTensorType& c_reg_tensor)
|
||||
{
|
||||
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
|
||||
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
|
||||
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
|
||||
static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
|
||||
static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
|
||||
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
|
||||
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
|
||||
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>
|
||||
blockwise_gemm_xdl_op{};
|
||||
|
||||
blockwise_gemm_xdl_op.Run(
|
||||
a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local partition per thread for C tensor.
|
||||
*
|
||||
* \note C output global memory layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension.
|
||||
* - MWave - The number of waves in single tile M dimension per tile.
|
||||
* - NWave - The number of waves in single tile N dimension per tile.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam ATileLayout A tensor layout.
|
||||
* \tparam BTileLayout B tensor layout.
|
||||
* \tparam BlockSize Number of threads in block.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
* \param c_local_tile_tensor C tensor in LDS memory for blockwise gemm
|
||||
* (MPerBlock, NPerBlock) layout.
|
||||
*
|
||||
* \return Partition c tensor for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
typename ATileLayout,
|
||||
typename BTileLayout,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits,
|
||||
typename CTensorType>
|
||||
__host__ __device__ constexpr auto
|
||||
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
|
||||
|
||||
using BlockwiseGemmXdlops =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>;
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
|
||||
|
||||
// Calculate offset on grid
|
||||
const auto c_thread_mtx_on_block =
|
||||
BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_grid_idx =
|
||||
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_grid));
|
||||
|
||||
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_grid_idx =
|
||||
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_grid));
|
||||
// Create partition shape based on descriptor dims.
|
||||
const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
|
||||
|
||||
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
layout(c_local_tile_tensor).GetUnrolledDescriptor());
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
|
||||
partition_shape, partition_desc);
|
||||
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
|
||||
c_local_tile_tensor.GetPointer(), partition_layout);
|
||||
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
n_thread_data_on_grid_idx[I1],
|
||||
m_thread_data_on_grid_idx[I2],
|
||||
m_thread_data_on_grid_idx[I3],
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
n_thread_data_on_grid_idx[I2]));
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local partition per thread for C tensor.
|
||||
*
|
||||
* \note C output Vgpr register layout (8D):
|
||||
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
|
||||
* dimension per tile.
|
||||
* - NXdlPerWave - The number of MFMA instructions run by single wave in N
|
||||
* dimension per tile.
|
||||
* - MWave - Equals to 1 since this is for single wave.
|
||||
* - NWave - Equals to 1 since this is for single wave.
|
||||
* - NumGroupsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumInputsBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - GroupSize - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
* - NumThreadsPerBlock - Mfma instruction internal layout (depeneds on the
|
||||
* instruction size).
|
||||
*
|
||||
* \tparam DataType Input data types.
|
||||
* \tparam ATileLayout A tensor layout.
|
||||
* \tparam BTileLayout B tensor layout.
|
||||
* \tparam BlockSize Number of threads in block.
|
||||
* \tparam GemmTraits Traits of gemm xdl operation.
|
||||
*
|
||||
* \return Vgpr c tensor for blockwise gemm.
|
||||
*/
|
||||
template <typename DataType,
|
||||
typename ATileLayout,
|
||||
typename BTileLayout,
|
||||
index_t BlockSize,
|
||||
typename GemmTraits>
|
||||
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr bool is_integer =
|
||||
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
|
||||
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
|
||||
|
||||
using ABlockDesc_K0_M_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
|
||||
using BBlockDesc_K0_N_K1_Type =
|
||||
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
|
||||
|
||||
using BlockwiseGemmXdlops =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
DataType,
|
||||
DataType,
|
||||
GemmAccDataType,
|
||||
ABlockDesc_K0_M_K1_Type,
|
||||
BBlockDesc_K0_N_K1_Type,
|
||||
GemmTraits::MPerXDL,
|
||||
GemmTraits::NPerXDL,
|
||||
GemmTraits::MXdlPerWave,
|
||||
GemmTraits::NXdlPerWave,
|
||||
GemmTraits::K1>;
|
||||
// Calcualte descriptor, shape and layout
|
||||
constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
|
||||
vgpr_desc.GetLengths()[I1],
|
||||
vgpr_desc.GetLengths()[I2],
|
||||
vgpr_desc.GetLengths()[I3],
|
||||
vgpr_desc.GetLengths()[I4],
|
||||
vgpr_desc.GetLengths()[I5],
|
||||
vgpr_desc.GetLengths()[I6],
|
||||
vgpr_desc.GetLengths()[I7]);
|
||||
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
|
||||
vgpr_shape, vgpr_desc);
|
||||
// Get vector type for Vgpr
|
||||
using BlockwiseGemmCThreadBufferType =
|
||||
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
|
||||
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
|
||||
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
|
||||
vgpr_layout);
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -10,8 +10,8 @@
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
namespace detail {
|
||||
namespace {
|
||||
namespace detail {
|
||||
/**
|
||||
* \brief Check if Tuple contains Slice object
|
||||
*
|
||||
@@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
|
||||
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
|
||||
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Tensor wrapper that performs static and dynamic buffer logic.
|
||||
@@ -209,7 +209,10 @@ struct Tensor
|
||||
public:
|
||||
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
|
||||
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
|
||||
using TensorElementType = ElementType; // DataType
|
||||
using TensorElementType = std::conditional_t<
|
||||
is_scalar_type<ElementType>::value,
|
||||
ElementType,
|
||||
typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
|
||||
|
||||
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
|
||||
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
|
||||
@@ -280,7 +283,7 @@ struct Tensor
|
||||
* \return Requested value.
|
||||
*/
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
|
||||
__host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
@@ -301,13 +304,13 @@ struct Tensor
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
|
||||
__host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
|
||||
__host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
@@ -319,7 +322,7 @@ struct Tensor
|
||||
* \return Requested value.
|
||||
*/
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
|
||||
__host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
@@ -340,13 +343,13 @@ struct Tensor
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
|
||||
__host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(Idxs... idxs)
|
||||
__host__ __device__ TensorElementType& operator()(Idxs... idxs)
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
@@ -366,7 +369,7 @@ struct Tensor
|
||||
*
|
||||
* \return Pointer.
|
||||
*/
|
||||
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
|
||||
__host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
|
||||
|
||||
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
|
||||
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
|
||||
@@ -395,10 +398,18 @@ struct Tensor
|
||||
ElementType,
|
||||
ElementSpaceSize,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
using StaticBufferType = StaticBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
size(Shape{}),
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
using StaticBufferType = std::conditional_t<
|
||||
is_scalar_type<ElementType>::value,
|
||||
StaticBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
size(Shape{}),
|
||||
true /*InvalidElementUseNumericalZeroValue*/>,
|
||||
StaticBufferTupleOfVector<BufferAddressSpace,
|
||||
TensorElementType,
|
||||
size(Shape{}) /
|
||||
scalar_type<std::remove_const_t<ElementType>>::vector_size,
|
||||
scalar_type<std::remove_const_t<ElementType>>::vector_size,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>>;
|
||||
// If register use static buffer, else use dynamic buffer
|
||||
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
|
||||
|
||||
|
||||
48
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
Normal file
48
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
/**
|
||||
* \brief Traits for blockwise gemm xdl.
|
||||
*
|
||||
* \tparam MPerXDLValue The MFMA instruction size in M dimension.
|
||||
* \tparam NPerXDLValue The MFMA instruction size in N dimension.
|
||||
* \tparam MXdlPerWaveValue The number of MFMA instructions run by single
|
||||
* wave in M dimension.
|
||||
* \tparam NXdlPerWaveValue The number of MFMA instructions run by single
|
||||
* wave in N dimension.
|
||||
* \tparam K1Value The number of K-dim elements that are packed together as
|
||||
* a separate logical dimension. Usually aligns with vector load size.
|
||||
*/
|
||||
template <index_t MPerXDLValue,
|
||||
index_t NPerXDLValue,
|
||||
index_t MXdlPerWaveValue,
|
||||
index_t NXdlPerWaveValue,
|
||||
index_t K1Value>
|
||||
struct BlockwisGemmXdlTraits
|
||||
{
|
||||
static constexpr index_t MPerXDL = MPerXDLValue;
|
||||
static constexpr index_t NPerXDL = NPerXDLValue;
|
||||
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
|
||||
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
|
||||
static constexpr index_t K1 = K1Value;
|
||||
};
|
||||
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "tensor_utils.hpp"
|
||||
#include "layout_utils.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
|
||||
@@ -14,6 +15,8 @@ namespace wrapper {
|
||||
|
||||
namespace {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/**
|
||||
* \brief Calculate shape for partition based on number of threads per each dim and
|
||||
* previous shape
|
||||
@@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
|
||||
const auto slice_len =
|
||||
ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
|
||||
return slice_len;
|
||||
},
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Apply projection.
|
||||
*
|
||||
* \param base_tuple Tuple to apply projection.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Multi index after projection.
|
||||
*/
|
||||
template <typename MultiIndex, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
|
||||
[[maybe_unused]] const ProjectionTuple& projection)
|
||||
{
|
||||
if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
|
||||
{
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto base_tuple_after_projection = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto i_num = Number<i.value>{};
|
||||
static_assert(
|
||||
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
|
||||
is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
|
||||
if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
|
||||
{
|
||||
// When slice (to remove), then insert empty tuple (will be removed in next
|
||||
// step).
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return base_tuple.At(i_num);
|
||||
}
|
||||
},
|
||||
Number<MultiIndex::Size()>{});
|
||||
// Remove empty tuples
|
||||
return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate shape with dims from projection.
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Shape with dims from projection
|
||||
*/
|
||||
template <typename... Ts, typename... Ps>
|
||||
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ps...>& projection)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
|
||||
{
|
||||
return size<i>(projection).to_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// number of shape element in actual fragment of shape and projection (method to
|
||||
// calculate shape idx)
|
||||
constexpr index_t shape_i =
|
||||
detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
|
||||
TupleSlice<0, i>(Tuple<Ps...>{}))
|
||||
.Size();
|
||||
return size<shape_i>(shape);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ps...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate total number of blocks.
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param tile_shape Tile shape.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Tuple with blocks number.
|
||||
*/
|
||||
template <typename... Ts, typename... Ls>
|
||||
template <typename... Ts, typename... Ls, typename... Ps>
|
||||
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ls...>& tile_shape)
|
||||
const Tuple<Ls...>& tile_shape,
|
||||
const Tuple<Ps...>& projection)
|
||||
{
|
||||
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
|
||||
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
|
||||
size<i>(tile_shape));
|
||||
},
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,8 +155,75 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
return thread_idxs * partition_lengths_seq + old_offset_idxs;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate default projection.
|
||||
*
|
||||
* \param tile_shape Tile shape.
|
||||
* \return Default projection (filled with Number<1>{}).
|
||||
*/
|
||||
template <typename TileShape>
|
||||
__host__ __device__ constexpr auto
|
||||
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
|
||||
{
|
||||
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread (At now only packed partition
|
||||
* is supported).
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param thread_lengths Layout of threads (could not be nested).
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
|
||||
// Calculate new partition shape
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
// Calculate projected thread lengths
|
||||
constexpr auto projected_thread_lengths =
|
||||
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
|
||||
constexpr auto partition_shape =
|
||||
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
|
||||
// Create Thread Cluster Descriptor
|
||||
constexpr auto partition_shape_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
|
||||
Number<decltype(partition_shape)::Size()>{});
|
||||
constexpr auto thread_lengths_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
|
||||
Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
|
||||
// Calculate thread idxs and offsets
|
||||
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
// Apply projection on thread idxs to remove not needed idxs
|
||||
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
|
||||
partition_shape, unrolled_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
|
||||
// Apply offsets
|
||||
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread (At now only packed partition
|
||||
* is supported).
|
||||
@@ -81,37 +234,12 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id)
|
||||
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
|
||||
const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id)
|
||||
{
|
||||
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
|
||||
// Calculate new partition shape
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
constexpr auto partition_shape =
|
||||
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
|
||||
// Create Thread Cluster Descriptor
|
||||
constexpr auto partition_lengths_seq = generate_sequence_v2(
|
||||
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_lengths_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
|
||||
Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
|
||||
// Calculate thread idxs and offsets
|
||||
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
const auto offset_multi_idxs =
|
||||
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
|
||||
partition_shape, flatten_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
|
||||
// Apply offsets
|
||||
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
return partition_tensor;
|
||||
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
|
||||
return make_local_partition(tensor, thread_lengths, thread_id, projection);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor,
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
|
||||
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
const BlockShapeTuple& tile_shape,
|
||||
const index_t block_id,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(BlockShapeTuple{}));
|
||||
|
||||
constexpr bool is_default_projection =
|
||||
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
|
||||
|
||||
if constexpr(BlockShapeTuple::Size() == I2)
|
||||
// TODO: Enable block_2_tile_map partitioning for non-default projection.
|
||||
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
|
||||
{
|
||||
// Optimized version for 2d tile shape [MxK]
|
||||
const auto block_2_tile_map =
|
||||
@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
|
||||
{
|
||||
// Calculate offsets
|
||||
// Sequence with data to process per block
|
||||
constexpr auto tile_shape_seq =
|
||||
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
|
||||
Number<BlockShapeTuple::Size()>{});
|
||||
constexpr auto projected_tile_shape =
|
||||
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
|
||||
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
|
||||
constexpr auto projected_tile_shape_seq =
|
||||
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
|
||||
Number<ProjectedTileShapeTuple::Size()>{});
|
||||
// Tuple with number of blocks
|
||||
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
|
||||
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
|
||||
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
|
||||
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
|
||||
const auto block_idxs =
|
||||
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
|
||||
const auto offset_multi_idxs =
|
||||
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
const auto tile_layout =
|
||||
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
|
||||
aligned_desc);
|
||||
Layout<remove_reference_t<ProjectedTileShapeTuple>, decltype(aligned_desc)>(
|
||||
projected_tile_shape, aligned_desc);
|
||||
auto tile_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
|
||||
// Apply offsets
|
||||
@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local tile for thread block. (At now only packed tile
|
||||
* is supported).
|
||||
*
|
||||
* \note Currently to get the best performance please use 2d shape.
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
|
||||
{
|
||||
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
|
||||
return make_local_tile(tensor, tile_shape, block_id, projection);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Pad tensor shapes to be adjusted to tile lengths.
|
||||
*
|
||||
*
|
||||
* \param tensor Tensor to pad.
|
||||
* \param tile_lengths Tile lengths to align tensor shape.
|
||||
* \return Padded tensor.
|
||||
*/
|
||||
template <typename TensorType, typename TileLengths>
|
||||
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
|
||||
{
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
// Generate sequence with ones to mark that all dims will be padded
|
||||
constexpr auto do_pads_seq =
|
||||
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
|
||||
// Create descriptor with padding
|
||||
auto padded_desc =
|
||||
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
|
||||
// Generate padded shape
|
||||
const auto padded_shape = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& dim = size<i>(tensor_shape);
|
||||
const auto& tile_length = size<i>(tile_lengths);
|
||||
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
|
||||
},
|
||||
Number<TileLengths::Size()>{});
|
||||
// Create layout and tensor
|
||||
const auto padded_layout =
|
||||
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
|
||||
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
@@ -19,9 +20,9 @@ namespace wrapper {
|
||||
* \brief Memory type, allowed members:
|
||||
* - Generic,
|
||||
* - Global,
|
||||
* - LDS,
|
||||
* - SGPR,
|
||||
* - VGPR,
|
||||
* - Lds,
|
||||
* - Sgpr,
|
||||
* - Vgpr,
|
||||
*/
|
||||
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
@@ -52,12 +53,8 @@ struct Slice
|
||||
__host__ __device__ constexpr auto range(const T& dim) const
|
||||
{
|
||||
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
|
||||
is_same_v<T, index_t>)
|
||||
is_same_v<std::remove_const_t<T>, index_t>)
|
||||
{
|
||||
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
|
||||
{
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
if(to_ < 0)
|
||||
{
|
||||
return dim - from_ + to_ + 1;
|
||||
@@ -70,9 +67,10 @@ struct Slice
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
|
||||
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
|
||||
(ToType{} < 0 || ToType{} > FromType{}),
|
||||
"Invalid range");
|
||||
if constexpr(to_ < 0)
|
||||
if constexpr(ToType{} < 0)
|
||||
{
|
||||
return dim - from_ + to_ + Number<1>{};
|
||||
}
|
||||
@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
|
||||
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Clear tensor. (Only for Vpgr/Sgpr)
|
||||
*
|
||||
* \param tensor Tensor to be cleared.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename UnrolledDescriptorType>
|
||||
__host__ __device__ void
|
||||
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
|
||||
{
|
||||
static_assert(
|
||||
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
|
||||
return tensor.GetBuffer().Clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor Layout.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user