mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
reimplement threadwise copy
This commit is contained in:
@@ -13,11 +13,13 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// its sub-length need to evenly divide the length of the last original dimension
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
// This functions assume each thread is reading and writing a normal (not merged) tensor,
|
||||
// to simplify index calculations. To satisfy this assumption, the user need to make sure
|
||||
// that, on a merged dimension that constains multiple original dimensions, the length of
|
||||
// the last original dimension need to be evenly dividable by its sub-lengths. Also, the
|
||||
// repeat-length on the merged dimension need to be 1.
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
@@ -88,30 +90,55 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into sub-tensor");
|
||||
|
||||
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into cluster");
|
||||
});
|
||||
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// its sub-length need to evenly divide the length of the last original dimension,
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto sub_length = SubLengths::Get(IDim);
|
||||
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
|
||||
sub_length ==
|
||||
0,
|
||||
"wrong!");
|
||||
// additional check for merged dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
// src
|
||||
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
|
||||
sub_length ==
|
||||
0,
|
||||
"wrong!");
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// the length of the last original dimension need to evenly dividable by its
|
||||
// sub-length,
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
constexpr auto idim_last_original_src =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(
|
||||
SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
|
||||
SubLengths::Get(IDim) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
// merged dimension should have repeat_lengths = 1
|
||||
static_assert(repeat_lengths[IDim] == 1,
|
||||
"wrong! repeat_lengths shoud be 1 on merged dimension");
|
||||
});
|
||||
|
||||
// dst
|
||||
static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// the length of the last original dimension need to evenly dividable by its
|
||||
// sub-length,
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
constexpr auto idim_last_original_dst =
|
||||
DstDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(
|
||||
DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
|
||||
SubLengths::Get(IDim) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
// merged dimension should have repeat_lengths = 1
|
||||
static_assert(repeat_lengths[IDim] == 1,
|
||||
"wrong! repeat_lengths shoud be 1 on merged dimension");
|
||||
});
|
||||
});
|
||||
|
||||
// calculate mThreadSrcOffset, mThreadDstOffset
|
||||
@@ -376,7 +403,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class TData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcCoordinate,
|
||||
@@ -428,16 +454,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
return RegisterBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
|
||||
{
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
|
||||
{
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
TData p_buffer[GetRegisterBufferSize()];
|
||||
@@ -466,16 +495,14 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
SrcDesc,
|
||||
ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SrcCoordinate,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
SubLengths>;
|
||||
|
||||
using ThreadwiseStore =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
RegisterBufferDesc,
|
||||
ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
DstCoordinate,
|
||||
|
||||
@@ -106,8 +106,107 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class TData,
|
||||
class SrcDesc,
|
||||
#if 0
|
||||
template <class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
|
||||
Array<index_t, nDim> dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::{} &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::{},
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 &&
|
||||
SliceLengths{DstVectorDim} % DstDataPerAccess == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// check vectorized memory access
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDIm>{};
|
||||
|
||||
static_if<!SrcDesc::ContainMultipleOriginalDimensions(
|
||||
src_vector_access_dim)>{}([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
}).Else{}([&](auto fwd) {
|
||||
static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
|
||||
SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
|
||||
static_if<!DstDesc::ContainMultipleOriginalDimensions(
|
||||
dst_vector_access_dim)>{}([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
}).Else{}([&](auto fwd) {
|
||||
static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
|
||||
DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v1(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
TData p_buffer[buffer_desc.GetElementSpace()];
|
||||
|
||||
// copy data from src into buffer
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths::Modify(
|
||||
src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
|
||||
|
||||
constexpr auto src_access_lengths_in_src_access_order =
|
||||
src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
|
||||
|
||||
static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
|
||||
}
|
||||
|
||||
private:
|
||||
Array<index_t, TData> mSrcSliceOrigin;
|
||||
Array<index_t, TData> mDstSliceOrigin;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcCoordinate,
|
||||
class DstCoordinate,
|
||||
@@ -116,18 +215,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
|
||||
: mSrcSliceOrigin(make_zero_array<index_t, nDim>()),
|
||||
mDstSliceOrigin(make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
|
||||
: ThreadwiseGenericTensorSliceCopy_v2(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
@@ -148,6 +247,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
}
|
||||
};
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
@@ -216,6 +316,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
});
|
||||
}
|
||||
|
||||
// T can be Sequence or Array
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
@@ -232,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
// private:
|
||||
private:
|
||||
SrcCoordinate mSrcSliceOrigin;
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user