fix merge from upstream

This commit is contained in:
carlushuang
2024-03-26 14:09:54 +00:00
parent c94b545747
commit 04ee01191a
105 changed files with 16558 additions and 2285 deletions

View File

@@ -5,8 +5,11 @@
#include "ck/wrapper/utils/layout_utils.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Layout wrapper that performs the tensor descriptor logic.
@@ -19,6 +22,8 @@ namespace wrapper {
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
// Disable from doxygen docs generation
/// @cond INTERNAL
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -246,6 +251,7 @@ struct Layout
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
/// @endcond
public:
using LayoutShape = Shape;
@@ -457,6 +463,8 @@ struct Layout
return unrolled_descriptor_;
}
// Disable from doxygen docs generation
/// @cond INTERNAL
private:
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
@@ -469,6 +477,7 @@ struct Layout
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
/// @endcond
};
} // namespace wrapper
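For reference, a minimal sketch of constructing such a Layout through the make_layout helper shown later in this diff (shape and stride values are illustrative; packed column-major strides are generated when the strides argument is omitted):

#include "ck/wrapper/utils/layout_utils.hpp"

// Nested shape ((4, 2)): merged 1D lengths are (8), matching the
// Descriptor1dType comment above. Stride values are illustrative.
const auto nested_shape =
    ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
const auto nested_strides =
    ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
const auto layout = ck::wrapper::make_layout(nested_shape, nested_strides);
// Strides omitted: packed (column-major) strides are generated.
const auto packed_layout = ck::wrapper::make_layout(nested_shape);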

View File

@@ -12,8 +12,11 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Perform optimized copy between two tensor partitions (threadwise copy).
@@ -61,12 +64,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
Sequence<true>,
Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
@@ -104,37 +107,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
const auto dst_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
auto transfer = ThreadwiseTensorSliceTransfer_v2<
std::remove_const_t<typename SrcTensorType::TensorElementType>,
std::remove_const_t<typename DstTensorType::TensorElementType>,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
I1,
false,
false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_slice_origin_idxs,
dst_tensor.GetBuffer());
}
else
@@ -183,10 +174,12 @@ template <typename DimAccessOrderTuple,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType,
typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
typename ThreadShape,
typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
@@ -199,12 +192,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2(
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq =
generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
// Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7<
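As a rough usage sketch of the new Layout-based blockwise_copy API: src_tile and dst_tile are assumed dynamic-buffer tensors, and the leading explicit template arguments (access order, vector dim, scalars per vector) are an assumption inferred from the signatures above.

// Thread layout replaces the former plain tuple of thread lengths.
const auto thread_layout = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}));
// DimAccessOrderTuple = (0, 1), VectorDim = 1, ScalarPerVector = 4
ck::wrapper::blockwise_copy<ck::Tuple<ck::Number<0>, ck::Number<1>>, 1, 4>(
    src_tile, dst_tile, thread_layout);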

View File

@@ -9,9 +9,14 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
@@ -45,11 +50,13 @@ __device__ constexpr auto GetBlockDescriptor()
} // namespace detail
} // namespace
/// @endcond
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
* data layout must be (NPerBlock, KPerBlock).
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
@@ -71,9 +78,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Number of threads in the block.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout.
* (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout.
* (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
* \param c_reg_tensor C tensor in VGPR memory for blockwise gemm.
*/
template <typename DataType,
@@ -86,6 +93,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor)
{
constexpr auto I3 = Number<3>{};
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
@@ -99,10 +108,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
@@ -168,14 +185,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -233,19 +258,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
auto sliced_desc = transform_tensor_descriptor(
partition_desc,
make_tuple(
make_slice_transform(partition_shape.At(Number<0>{}),
m_thread_data_on_grid_idx[I0],
partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<1>{}),
n_thread_data_on_grid_idx[I0],
partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<2>{}),
m_thread_data_on_grid_idx[I1],
partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<3>{}),
n_thread_data_on_grid_idx[I1],
partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<4>{}),
m_thread_data_on_grid_idx[I2],
partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
make_slice_transform(partition_shape.At(Number<5>{}),
m_thread_data_on_grid_idx[I3],
partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
make_slice_transform(partition_shape.At(Number<6>{}),
m_thread_data_on_grid_idx[I4],
partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
make_slice_transform(partition_shape.At(Number<7>{}),
n_thread_data_on_grid_idx[I2],
partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
lower_upper_dims,
lower_upper_dims);
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
partition_shape, partition_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor;
}
@@ -292,14 +343,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -326,9 +385,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc);
// Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType =
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout);
}
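A hedged usage sketch of the blockwise gemm entry point: a_lds_tile, b_lds_tile and c_vgpr are assumed tensors in LDS, LDS and VGPR respectively, and the explicit template-argument order (DataType, BlockSize, GemmTraits) is an assumption based on the \tparam docs above.

using Traits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1;
// A/B tiles may now also use the 3D (K0PerBlock, M/NPerBlock, K1) layout.
ck::wrapper::blockwise_gemm_xdl<ck::half_t, 256, Traits>(
    a_lds_tile, b_lds_tile, c_vgpr);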

View File

@@ -7,9 +7,14 @@
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
@@ -172,10 +177,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
}
}
template <typename... Ts, typename Shape, typename FlattenDescriptor>
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape,
const FlattenDescriptor& flatten_desc)
const UnrolledDescriptor& flatten_desc)
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
@@ -189,6 +194,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
}
} // namespace detail
} // namespace
/// @endcond
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
@@ -394,6 +400,8 @@ struct Tensor
}
private:
// Disable from doxygen docs generation
/// @cond INTERNAL
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType,
ElementSpaceSize,
@@ -428,6 +436,7 @@ struct Tensor
// tensor descriptor (thus all its transforms) and is linear (1D).
// We store base_offset_ to avoid multiple recalculations.
index_t base_offset_;
/// @endcond
};
} // namespace wrapper
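For context, a minimal sketch of wrapping raw memory in a Tensor, following the make_tensor pattern used elsewhere in this diff (p_data is an assumed float* into global memory):

const auto layout = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
    p_data, layout);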

View File

@@ -5,8 +5,11 @@
#include "ck/ck.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Traits for blockwise gemm xdl.
@@ -20,48 +23,57 @@ namespace wrapper {
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template <index_t MPerXDLValue,
index_t NPerXDLValue,
index_t MXdlPerWaveValue,
index_t NXdlPerWaveValue,
index_t K1Value>
template <typename MPerXDLValue,
typename NPerXDLValue,
typename MXdlPerWaveValue,
typename NXdlPerWaveValue,
typename K1Value>
struct BlockwisGemmXdlTraits
{
static constexpr index_t MPerXDL = MPerXDLValue;
static constexpr index_t NPerXDL = NPerXDLValue;
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
static constexpr index_t K1 = K1Value;
static constexpr auto MPerXDL = MPerXDLValue{};
static constexpr auto NPerXDL = NPerXDLValue{};
static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
static constexpr auto K1 = K1Value{};
};
// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{
};
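Since Number<N> behaves like an integral_constant, the converted members still compare as plain integers, so existing call sites should be unaffected; a quick sanity sketch:

using Traits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1;
static_assert(Traits::MPerXDL == 32 && Traits::K1 == 4, "values unchanged");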

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
} // namespace wrapper
} // namespace ck
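Typical application of the macro to a wrapper kernel (the kernel name and signature here are illustrative only):

template <typename... Args>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ kernel_entry(Args... args)
{
    // device-side wrapper code
}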

View File

@@ -15,12 +15,16 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
namespace wrapper {
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
// Disable from doxygen docs generation
/// @cond
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
// forward declaration
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
@@ -29,6 +33,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
namespace detail {
/**
* \brief Generate packed (column-major) strides if not passed
*
@@ -83,6 +88,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace
/// @endcond
@@ -98,8 +104,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, strides));
}
/**
@@ -112,13 +119,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
/**
* \private
* \brief Get dim.
@@ -152,8 +158,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requested sub layout.
*/
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
const auto& shape = layout.GetShape();
const auto new_shape = get<idx>(shape);
@@ -427,5 +433,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
return layout.GetShape();
}
// pad
/**
* \brief Pad layout shape so that it is aligned to the tile lengths.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
const TileLengths& tile_lengths)
{
auto& unrolled_desc = layout.GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
// Create layout
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
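A short sketch of the new pad helper (values illustrative): padding a runtime (30, 30) layout to 8x8 tiles yields a (32, 32) padded shape.

const auto layout       = ck::wrapper::make_layout(ck::make_tuple(30, 30));
const auto tile_lengths = ck::make_tuple(ck::Number<8>{}, ck::Number<8>{});
// ceil(30 / 8) * 8 = 32 in each dim
const auto padded = ck::wrapper::pad(layout, tile_lengths);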
// unmerge
/**
* \brief Unmerge selected dim in layout.
*
* \tparam Idx Index to dimension being unmerged.
* \param layout Layout to unmerge.
* \param new_lengths Dimensions into which the indicated dimension will be divided.
* \param new_indexes Indexes used to shuffle dims. The indexes for the unmerged dim should be nested in a tuple.
* \return Unmerged layout.
*/
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
const NewLengths& new_lengths,
[[maybe_unused]] const NewIdxs& new_indexes)
{
const auto& layout_shape = shape(layout);
auto& unrolled_desc = layout.GetUnrolledDescriptor();
constexpr auto dims = Shape::Size();
// Generate transforms
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == Idx)
{
return make_unmerge_transform(new_lengths);
}
else
{
return make_pass_through_transform(layout_shape.At(i));
}
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
{
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
return to_sequence(idxs_tuple);
}
else
{
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
return Sequence<index>{};
}
},
Number<dims>{});
const auto unmerged_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
const auto unmerged_shape =
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
// Create layout
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
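A short sketch of unmerge (values illustrative): splitting dim 1 of a (4, 32) layout into (2, 16), with the new dims nested at positions 1 and 2.

const auto layout = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<4>{}, ck::Number<32>{}));
const auto unmerged = ck::wrapper::unmerge<1>(
    layout,
    ck::make_tuple(ck::Number<2>{}, ck::Number<16>{}),
    ck::make_tuple(ck::Number<0>{},
                   ck::make_tuple(ck::Number<1>{}, ck::Number<2>{})));
// Resulting shape: (4, 2, 16)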
} // namespace wrapper
} // namespace ck

View File

@@ -6,13 +6,17 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
@@ -44,8 +48,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Multi index after projection.
*/
template <typename MultiIndex, typename ProjectionTuple>
@@ -73,7 +78,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
}
else
{
return base_tuple.At(i_num);
return make_tuple(base_tuple.At(i_num));
}
},
Number<MultiIndex::Size()>{});
@@ -86,8 +91,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Shape with dims from projection
*/
template <typename... Ts, typename... Ps>
@@ -119,22 +125,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with the number of blocks per dim.
*/
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape,
const Tuple<Ps...>& projection)
const Tuple<Ls...>& tile_shape)
{
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple(
[&](auto i) {
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
[&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
Number<Tuple<Ls...>::Size()>{});
}
@@ -155,6 +153,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
/**
* \brief Select dims to partition (slice dims are skipped).
*
* \param block_idxs Input block indexes.
* \return Partitioned dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
const auto dims_to_partition = generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return Number<i>{};
}
else
{
return Tuple<>{};
}
},
Number<BlockIdxs::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(dims_to_partition);
}
/**
* \brief Replace slices with zeros (slice dims are not partitioned).
*
* \param block_idxs Input block indexes.
* \return Parsed dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
return generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return block_idxs.At(i);
}
else
{
return Number<0>{};
}
},
Number<BlockIdxs::Size()>{});
}
/**
* \brief Calculate default projection.
*
@@ -168,59 +214,96 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
/**
* \brief Calculate thread multi index from 1d thread index.
*
* \param thread_layout Layout of threads (must not be nested).
* \param thread_id Thread index represented as an integer.
* \return Multi index.
*/
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id)
{
static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
"Thread layout should not be transformed.");
constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
constexpr auto shape = ThreadShape{};
constexpr auto strides = embed_transform.coefficients_;
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
return (thread_id / strides.At(num_i)) % shape.At(num_i);
},
Number<ThreadShape::Size()>{});
}
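A worked example of the index math above (values illustrative): for ThreadShape (4, 8) with packed column-major strides (1, 4) and thread_id = 13,

// dim 0: (13 / 1) % 4 = 1
// dim 1: (13 / 4) % 8 = 3
// thread multi index = (1, 3); check: 1 * 1 + 3 * 4 = 13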
} // namespace detail
} // namespace
/// @endcond
/**
* \brief Create local partition for a thread (for now only the packed
* partition is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested).
* \param thread_layout Layout of threads (must not be transformed).
* \param thread_id Thread index represented as an integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
template <typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc,
typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
static_assert(!IsNestedTuple(ThreadShape{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths
constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
// Apply projection on thread idxs to remove not needed idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Slice descriptor
const auto transforms = generate_tuple(
[&](auto i) {
return make_slice_transform(partition_shape.At(i),
offset_multi_idxs.At(i),
partition_shape.At(i) + offset_multi_idxs.At(i));
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
partition_shape, unrolled_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
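A usage sketch (block_tile is an assumed tile tensor; the overload below defaults the projection):

const auto thread_layout = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}));
auto thread_partition = ck::wrapper::make_local_partition(
    block_tile, thread_layout, static_cast<ck::index_t>(threadIdx.x));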
@@ -233,12 +316,13 @@ make_local_partition(TensorType& tensor,
* \param thread_id Thread index represented as an integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
const index_t thread_id)
{
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
@@ -252,21 +336,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param block_idxs Tuple of block indexes represented as integers. If a dim
* index is a slice, the whole dim is taken.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs,
typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const index_t block_id,
const BlockIdxs& block_idxs,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
static_assert(!IsNestedTuple(BlockIdxs{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -274,49 +361,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
// Number of dims which are partitioned
constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
if constexpr(decltype(dims_to_partition)::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto shape_with_projection_dims =
detail::CalculateShapeWithProjection(shape(tensor), projection);
// Set Value for M, N partition
const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// Get 1D block id
const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
// Optimized version for 2d tile shape [MxN]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
NPerBlock,
remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// Apply 0 for non partitioned dims
const auto offset_multi_idxs = generate_tuple(
[&](auto i) {
if constexpr(i == dims_to_partition.At(I0))
{
return m_block_data_idx_on_grid;
}
else if constexpr(i == dims_to_partition.At(I1))
{
return n_block_data_idx_on_grid;
}
else
{
return Number<0>{};
}
},
Number<BlockShapeTuple::Size()>{});
const auto projected_offset_multi_idxs =
detail::ApplyProjection(offset_multi_idxs, projection);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
projected_tile_shape, aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
const auto projected_block_idxs =
to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
@@ -338,52 +453,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param block_idxs Tuple of block indexes represented as integers. If a dim
* index is a slice, the whole dim is taken.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxs& block_idxs)
{
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection);
}
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
} // namespace wrapper
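A usage sketch of the new block-index-tuple API (padded_tensor is an assumed global tensor; the projection is defaulted by the overload above):

const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<64>{});
auto tile = ck::wrapper::make_local_tile(
    padded_tensor, tile_shape,
    ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                   static_cast<ck::index_t>(blockIdx.y)));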

View File

@@ -13,8 +13,11 @@
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Memory type, allowed members:
@@ -27,7 +30,7 @@ namespace wrapper {
using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
/// @cond INTERNAL
// forward declarations
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;