mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
added new tensor copy operator
This commit is contained in:
@@ -24,7 +24,7 @@ template <index_t BlockSize,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class DataClusterLengths,
|
||||
class ThreadClusterLengths,
|
||||
class ThreadClusterArrangeOrder,
|
||||
class SrcAccessOrder,
|
||||
class DstAccessOrder,
|
||||
@@ -65,7 +65,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// check NDim consistency
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SubLengths::GetSize() && nDim == DataClusterLengths::GetSize() &&
|
||||
nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
|
||||
"wrong");
|
||||
@@ -78,13 +79,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
// thread cluster
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
// BlockSize
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
|
||||
|
||||
// divide work
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
|
||||
@@ -160,9 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
{
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
|
||||
@@ -170,14 +171,15 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
return thread_tensor_desc.GetElementSpace();
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_Buffer) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
|
||||
constexpr auto data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * ThreadClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
@@ -187,25 +189,24 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
constexpr auto src_thread_data_multi_id_begin =
|
||||
repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths;
|
||||
constexpr auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
constexpr index_t Buffer_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin);
|
||||
#else
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
const auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
const index_t Buffer_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
// By position the origin of the per-thread window at the point, where multi-index
|
||||
@@ -219,7 +220,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
p_Buffer + Buffer_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_sub_tensor_lengths,
|
||||
SrcAccessOrder{},
|
||||
@@ -227,38 +228,38 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_Buffer,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
|
||||
constexpr auto data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * ThreadClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths;
|
||||
constexpr auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
constexpr index_t Buffer_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin);
|
||||
|
||||
constexpr index_t dst_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#else
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
const auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
const index_t Buffer_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin);
|
||||
|
||||
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#endif
|
||||
@@ -271,7 +272,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// If in the future, you want to enable SubLengths > 1 at the merged dimension,
|
||||
// special care in implementation is needed
|
||||
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
p_Buffer + Buffer_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
DstDesc{},
|
||||
p_dst + dst_offset + mThreadDstOffset,
|
||||
@@ -284,10 +285,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
Float p_clipboard[GetRegisterClipboardSize()];
|
||||
Float p_Buffer[GetRegisterBufferSize()];
|
||||
|
||||
RunLoadRegisterClipboard(p_src, p_clipboard);
|
||||
RunStoreRegisterClipboard(p_clipboard, p_dst);
|
||||
RunLoadRegisterBuffer(p_src, p_Buffer);
|
||||
RunStoreRegisterBuffer(p_Buffer, p_dst);
|
||||
}
|
||||
|
||||
// When moving the slicing windows along a merged dimension, if the strides of the
|
||||
@@ -382,24 +383,30 @@ template <index_t BlockSize,
|
||||
class DstCoordinate,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class DataClusterLengths,
|
||||
class ThreadClusterLengths,
|
||||
class ThreadClusterArrangeOrder>
|
||||
struct BlockwiseGenericTensorSliceCopy_v2
|
||||
{
|
||||
using ThreadwiseCopy = ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcCoordinate,
|
||||
DstCoordinate,
|
||||
SubLengths>;
|
||||
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
|
||||
DstCoordinate dst_block_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
@@ -409,43 +416,66 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
|
||||
const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
|
||||
|
||||
mThreadwiseCopy.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin);
|
||||
mThreadwiseCopy.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin);
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
{
|
||||
return RegisterBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
|
||||
{
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
|
||||
{
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
mThreadwiseCopy.Run(p_src, p_dst);
|
||||
}
|
||||
TData p_buffer[GetRegisterBufferSize()];
|
||||
|
||||
template <class SrcMergedDimSubLengthsHack, class DstMergedDimSubLengthsHack>
|
||||
__device__ void Run_hack(const TData* p_src,
|
||||
TData* p_dst,
|
||||
SrcMergedDimSubLengthsHack,
|
||||
DstMergedDimSubLengthsHack) const
|
||||
{
|
||||
// hacks to isolate merged dimension from normal dimensions, and calculate their offset
|
||||
// seperately
|
||||
// SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions,
|
||||
// but 1 on normal dimensions;
|
||||
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
mThreadwiseCopy.Run_hack(
|
||||
p_src, p_dst, SrcMergedDimSubLengthsHack{}, DstMergedDimSubLengthsHack{});
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
mThreadwiseCopy.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
__device__ void MoveDstSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
mThreadwiseCopy.MoveDstSlicingWindow(step_sizes, positive_direction);
|
||||
mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
// private:
|
||||
ThreadwiseCopy mThreadwiseCopy;
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SrcCoordinate,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
SubLengths>;
|
||||
|
||||
using ThreadwiseStore =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
RegisterBufferDesc,
|
||||
DstDesc,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
DstCoordinate,
|
||||
SubLengths>;
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -138,47 +138,17 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
template <class TDesc, class Seq>
|
||||
struct IsolateMergedDimSliceLengthsHack
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
template <class IDim>
|
||||
__device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return TDesc::ContainMultipleOriginalDimensions(idim) ? Seq{}[idim] : 1;
|
||||
}
|
||||
};
|
||||
|
||||
TData p_buffer_[buffer_desc.GetElementSpace()];
|
||||
TData* p_buffer = p_buffer_;
|
||||
|
||||
#if 0
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] =
|
||||
p_src[(mSrcSliceOrigin + data_id).GetOffset()];
|
||||
});
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(mDstSliceOrigin + data_id).GetOffset()] =
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)];
|
||||
});
|
||||
#elif 1
|
||||
auto src_slice_origin = mSrcSliceOrigin;
|
||||
auto dst_slice_origin = mDstSliceOrigin;
|
||||
|
||||
const TData* p_src_tmp = p_src + src_slice_origin.RepositionOrigin();
|
||||
TData* p_dst_tmp = p_dst + dst_slice_origin.RepositionOrigin();
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] =
|
||||
p_src_tmp[(src_slice_origin + data_id).GetOffset()];
|
||||
});
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst_tmp[(dst_slice_origin + data_id).GetOffset()] =
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)];
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class SrcMergedDimSliceLengthsHack, class DstMergedDimSliceLengthsHack>
|
||||
__device__ void Run_hack(const TData* p_src,
|
||||
TData* p_dst,
|
||||
SrcMergedDimSliceLengthsHack,
|
||||
DstMergedDimSliceLengthsHack) const
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
@@ -191,6 +161,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
// but 1 on normal dimensions;
|
||||
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
using SrcMergedDimSliceLengthsHack =
|
||||
typename sequence_gen<SliceLengths::GetSize(),
|
||||
IsolateMergedDimSliceLengthsHack<SrcDesc, SliceLengths>>::type;
|
||||
|
||||
using SrcNormalDimSliceLengthsHack =
|
||||
decltype((SliceLengths{} + Number<1>{}) - SrcMergedDimSliceLengthsHack{});
|
||||
|
||||
@@ -216,6 +190,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
// but 1 on normal dimensions;
|
||||
// DstNormalDimSliceLengthsHack has entry same as SliceLengths on dst normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
using DstMergedDimSliceLengthsHack =
|
||||
typename sequence_gen<SliceLengths::GetSize(),
|
||||
IsolateMergedDimSliceLengthsHack<DstDesc, SliceLengths>>::type;
|
||||
|
||||
using DstNormalDimSliceLengthsHack =
|
||||
decltype((SliceLengths{} + Number<1>{}) - DstMergedDimSliceLengthsHack{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user