mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
added ThreadwiseGenericTensorSliceCopy_v2r1
This commit is contained in:
@@ -412,7 +412,13 @@ template <index_t BlockSize,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class ThreadClusterLengths,
|
||||
class ThreadClusterArrangeOrder>
|
||||
class ThreadClusterArrangeOrder,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v2
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
@@ -496,6 +502,7 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
#if 0
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
@@ -509,6 +516,33 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
DstCoordinate,
|
||||
SubLengths>;
|
||||
#else
|
||||
// Per-thread copy type used for the load stage: SrcDesc -> RegisterBufferDesc.
// Template arguments follow the parameter order of
// ThreadwiseGenericTensorSliceCopy_v2r1:
//   src desc, dst desc, src coordinate, dst coordinate, slice lengths,
//   src access order, dst access order, src vector dim, dst vector dim,
//   src data per access, dst data per access.
// The src-side order and vector dim are passed for both the src and dst slots,
// and the buffer side is accessed one element at a time (dst data per access = 1).
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SrcCoordinate,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
// Per-thread copy type used for the store stage: RegisterBufferDesc -> DstDesc.
// Template arguments follow the parameter order of
// ThreadwiseGenericTensorSliceCopy_v2r1:
//   src desc, dst desc, src coordinate, dst coordinate, slice lengths,
//   src access order, dst access order, src vector dim, dst vector dim,
//   src data per access, dst data per access.
// The dst-side order and vector dim are passed for both the src and dst slots,
// and the buffer side is accessed one element at a time (src data per access = 1).
using ThreadwiseStore =
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
DstCoordinate,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
#endif
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
|
||||
#endif
|
||||
|
||||
// Selects the loop implementation used inside
// ThreadwiseGenericTensorSliceCopy_v2r1::Run:
//   non-zero : static_ford / static_for with constexpr multi-indices
//   zero     : ford and plain for loops with run-time multi-indices
// Defaults to 0 unless the macro is already defined before this point.
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This threadwise copy allows vector access of src and dst.
|
||||
@@ -590,5 +594,313 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
#if 1
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the dimensions of vector access to be different on src and dst.
|
||||
// It also allows the vector size to be different on src and dst.
|
||||
// It also allows order of access to be different on src and dst.
|
||||
// It use register as buffer to hold all data moving from src to dst.
|
||||
// It is designed for copying small amount of data, and src and dst are
|
||||
// device memory or LDS.
|
||||
// When copying large amout of data, let's hope compiler will reduce register
|
||||
// used for the buffer.
|
||||
template <class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcCoordinate,
|
||||
class DstCoordinate,
|
||||
class SliceLengths,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1(SrcCoordinate src_slice_origin,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v2r1(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <class TDesc, class Lengths>
|
||||
struct IsolateMergedDimLengths
|
||||
{
|
||||
template <class IDim>
|
||||
__device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return TDesc::ContainMultipleOriginalDimensions(idim) ? Lengths{}[idim] : 1;
|
||||
}
|
||||
};
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
TData p_buffer_[buffer_desc.GetElementSpace()];
|
||||
TData* p_buffer = p_buffer_;
|
||||
|
||||
// copy data from src into buffer
|
||||
{
|
||||
using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths::Modify(
|
||||
src_vector_access_dim,
|
||||
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
|
||||
|
||||
// Offset w.r.t merged dimensions need to be calculated at run-time. Offset w.r.t
|
||||
// normal dimensions is known at compile time.
|
||||
// Below is a hack to isolate merged dimension id from normal dimension id, so the
|
||||
// corresponding offset can be calculated seperately at run-time and compile-time.
|
||||
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
|
||||
// merged dimensions, and has value = 1 on normal dimensions;
|
||||
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
|
||||
// normal dimensions, and has value = 1 on merged dimensions;
|
||||
constexpr auto src_merged_dim_access_lengths = typename sequence_gen<
|
||||
nDim,
|
||||
IsolateMergedDimLengths<SrcDesc, decltype(src_access_lengths)>>::type{};
|
||||
|
||||
constexpr auto src_normal_dim_access_lengths =
|
||||
src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
|
||||
// offset w.r.t. merged dimension need to be computed at run-time
|
||||
static_ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}([&](
|
||||
auto src_merged_dim_access_id_) {
|
||||
|
||||
constexpr auto src_merged_dim_access_id = decltype(src_merged_dim_access_id_){};
|
||||
|
||||
constexpr auto src_merged_dim_data_id = src_merged_dim_access_id.Modify(
|
||||
src_vector_access_dim,
|
||||
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access);
|
||||
|
||||
const TData* p_src_tmp =
|
||||
p_src + (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
|
||||
|
||||
// offset w.r.t. normal dimension can be computed at compile-time
|
||||
static_ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
|
||||
auto src_normal_dim_access_id_) {
|
||||
|
||||
constexpr auto src_normal_dim_access_id = decltype(src_normal_dim_access_id_){};
|
||||
|
||||
constexpr auto src_normal_dim_data_id = src_normal_dim_access_id.Modify(
|
||||
src_vector_access_dim,
|
||||
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access);
|
||||
|
||||
constexpr index_t src_normal_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
|
||||
|
||||
// load vector from src
|
||||
const src_vector_t vector_data =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src_tmp[src_normal_offset]);
|
||||
|
||||
// unpack vector into buffer
|
||||
static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id =
|
||||
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
|
||||
src_vector_access_dim, i);
|
||||
|
||||
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
|
||||
|
||||
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
|
||||
});
|
||||
});
|
||||
});
|
||||
#else
|
||||
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}([&](
|
||||
auto src_merged_dim_access_id) {
|
||||
|
||||
auto src_merged_dim_data_id = src_merged_dim_access_id;
|
||||
src_merged_dim_data_id(src_vector_access_dim) =
|
||||
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;
|
||||
|
||||
const TData* p_src_tmp =
|
||||
p_src + (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
|
||||
|
||||
// these should be compile-time known
|
||||
ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
|
||||
auto src_normal_dim_access_id) {
|
||||
|
||||
auto src_normal_dim_data_id = src_normal_dim_access_id;
|
||||
src_normal_dim_data_id(src_vector_access_dim) =
|
||||
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;
|
||||
|
||||
const index_t src_normal_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
|
||||
|
||||
// load vector from src
|
||||
const src_vector_t vector_data =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src_tmp[src_normal_offset]);
|
||||
|
||||
// unpack vector into buffer
|
||||
for(index_t i = 0; i < SrcDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(src_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
|
||||
|
||||
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
|
||||
}
|
||||
});
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
// copy data from buffer into dst
|
||||
{
|
||||
using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths::Modify(
|
||||
dst_vector_access_dim,
|
||||
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
|
||||
|
||||
constexpr auto dst_merged_dim_access_lengths = typename sequence_gen<
|
||||
nDim,
|
||||
IsolateMergedDimLengths<DstDesc, decltype(dst_access_lengths)>>::type{};
|
||||
|
||||
constexpr auto dst_normal_dim_access_lengths =
|
||||
dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
|
||||
// offset w.r.t. merged dimension need to be computed at run-time
|
||||
static_ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_merged_dim_access_id_) {
|
||||
|
||||
constexpr auto dst_merged_dim_access_id = decltype(dst_merged_dim_access_id_){};
|
||||
|
||||
constexpr auto dst_merged_dim_data_id = dst_merged_dim_access_id.Modify(
|
||||
dst_vector_access_dim,
|
||||
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access);
|
||||
|
||||
TData* p_dst_tmp = p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
|
||||
|
||||
// offset w.r.t. normal dimension can be computed at compile-time
|
||||
static_ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_normal_dim_access_id_) {
|
||||
constexpr auto dst_normal_dim_access_id = decltype(dst_normal_dim_access_id_){};
|
||||
|
||||
constexpr auto dst_normal_dim_data_id = dst_normal_dim_access_id.Modify(
|
||||
dst_vector_access_dim,
|
||||
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access);
|
||||
|
||||
dst_vector_t vector_data;
|
||||
|
||||
// pack vector from buffer
|
||||
static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id =
|
||||
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
|
||||
dst_vector_access_dim, i);
|
||||
|
||||
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
|
||||
|
||||
reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
|
||||
});
|
||||
|
||||
constexpr index_t dst_normal_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
|
||||
|
||||
// write vector into dst
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) = vector_data;
|
||||
});
|
||||
});
|
||||
#else
|
||||
// offset w.r.t. merged dimension need to be computed at run-time
|
||||
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_merged_dim_access_id) {
|
||||
|
||||
auto dst_merged_dim_data_id = dst_merged_dim_access_id;
|
||||
dst_merged_dim_data_id(dst_vector_access_dim) =
|
||||
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
|
||||
|
||||
TData* p_dst_tmp = p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
|
||||
|
||||
// offset w.r.t. normal dimension can be computed at compile-time
|
||||
ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_normal_dim_access_id) {
|
||||
|
||||
auto dst_normal_dim_data_id = dst_normal_dim_access_id;
|
||||
dst_normal_dim_data_id(dst_vector_access_dim) =
|
||||
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
|
||||
|
||||
dst_vector_t vector_data;
|
||||
|
||||
// pack vector from buffer
|
||||
for(index_t i = 0; i < DstDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(dst_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
|
||||
|
||||
reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
|
||||
}
|
||||
|
||||
const index_t dst_normal_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
|
||||
|
||||
// write vector into dst
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) = vector_data;
|
||||
});
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// T can be Sequence or Array
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mSrcSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mDstSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoordinate mSrcSliceOrigin;
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user