reimplement threadwise copy

This commit is contained in:
Chao Liu
2019-08-06 17:41:58 -05:00
parent adc1008836
commit fdcfae3a62
10 changed files with 223 additions and 50 deletions

View File

@@ -13,11 +13,13 @@
namespace ck {
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst.
// on a merged dimension that constains multiple original dimensions,
// its sub-length need to evenly divide the length of the last original dimension
// so each thread is effectively reading a normal (not merged) tensor
// This functions assume each thread is reading and writing a normal (not merged) tensor,
// to simplify index calculations. To satisfy this assumption, the user need to make sure
// that, on a merged dimension that constains multiple original dimensions, the length of
// the last original dimension need to be evenly dividable by its sub-lengths. Also, the
// repeat-length on the merged dimension need to be 1.
template <index_t BlockSize,
class Float,
class SrcDesc,
@@ -88,30 +90,55 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
static_for<0, nDim, 1>{}([&](auto IDim) {
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into sub-tensor");
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into cluster");
});
// on a merged dimension that constains multiple original dimensions,
// its sub-length need to evenly divide the length of the last original dimension,
// so each thread is effectively reading a normal (not merged) tensor
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto sub_length = SubLengths::Get(IDim);
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
sub_length ==
0,
"wrong!");
// additional check for merged dimension
static_for<0, nDim, 1>{}([&](auto IDim_) {
// src
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
sub_length ==
0,
"wrong!");
// on a merged dimension that constains multiple original dimensions,
// the length of the last original dimension need to evenly dividable by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
constexpr auto idim_last_original_src =
SrcDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
// dst
static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
// on a merged dimension that constains multiple original dimensions,
// the length of the last original dimension need to evenly dividable by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
constexpr auto idim_last_original_dst =
DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
});
// calculate mThreadSrcOffset, mThreadDstOffset
@@ -376,7 +403,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
};
template <index_t BlockSize,
class TData,
class SrcDesc,
class DstDesc,
class SrcCoordinate,
@@ -428,16 +454,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
return RegisterBufferDesc::GetElementSpace();
}
template <class TData>
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
{
mThreadwiseLoad.Run(p_src, p_buffer);
}
template <class TData>
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
{
mThreadwiseStore.Run(p_buffer, p_dst);
}
template <class TData>
__device__ void Run(const TData* p_src, TData* p_dst) const
{
TData p_buffer[GetRegisterBufferSize()];
@@ -466,16 +495,14 @@ struct BlockwiseGenericTensorSliceCopy_v2
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
using ThreadwiseLoad =
ThreadwiseGenericTensorSliceCopy_v2<TData,
SrcDesc,
ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
RegisterBufferDesc,
SrcCoordinate,
NormalTensorCoordinate<RegisterBufferDesc>,
SubLengths>;
using ThreadwiseStore =
ThreadwiseGenericTensorSliceCopy_v2<TData,
RegisterBufferDesc,
ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
DstDesc,
NormalTensorCoordinate<RegisterBufferDesc>,
DstCoordinate,

View File

@@ -106,8 +106,107 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
#endif
}
template <class TData,
class SrcDesc,
#if 0
template <class SrcDesc,
class DstDesc,
class SliceLengths,
class SrcDimAccessOrder,
class DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1
{
static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
Array<index_t, nDim> dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() &&
nDim == DstDimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::{} &&
is_valid_sequence_map<DstDimAccessOrder>::{},
"wrong! map is not valid");
static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 &&
SliceLengths{DstVectorDim} % DstDataPerAccess == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDIm>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(
src_vector_access_dim)>{}([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
}).Else{}([&](auto fwd) {
static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(
dst_vector_access_dim)>{}([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
}).Else{}([&](auto fwd) {
static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
: ThreadwiseGenericTensorSliceCopy_v1(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <class TData>
__device__ void Run(const TData* p_src, TData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
TData p_buffer[buffer_desc.GetElementSpace()];
// copy data from src into buffer
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
constexpr auto src_access_lengths = SliceLengths::Modify(
src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
constexpr auto src_access_lengths_in_src_access_order =
src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
}
private:
Array<index_t, TData> mSrcSliceOrigin;
Array<index_t, TData> mDstSliceOrigin;
};
#endif
template <class SrcDesc,
class DstDesc,
class SrcCoordinate,
class DstCoordinate,
@@ -116,18 +215,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
: mSrcSliceOrigin(make_zero_array<index_t, nDim>()),
mDstSliceOrigin(make_zero_array<index_t, nDim>())
{
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
DstCoordinate dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
: ThreadwiseGenericTensorSliceCopy_v2(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
@@ -148,6 +247,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
}
};
template <class TData>
__device__ void Run(const TData* p_src, TData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
@@ -216,6 +316,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
});
}
// T can be Sequence or Array
template <class T, bool PositiveDirection>
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
@@ -232,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
// private:
private:
SrcCoordinate mSrcSliceOrigin;
DstCoordinate mDstSliceOrigin;
};