mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
experimenting new merged tensor copy
This commit is contained in:
@@ -295,9 +295,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
// do work
|
||||
for(index_t e = 0; e < E; e += EPerBlock)
|
||||
{
|
||||
// marching slicing window
|
||||
#if 0 // debug
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
#else
|
||||
using InSrcMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>;
|
||||
using InDstMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>;
|
||||
blockwise_in_copy.Run_hack(p_in_global,
|
||||
p_in_block,
|
||||
InSrcMergedDimSubLengthsHack{},
|
||||
InDstMergedDimSubLengthsHack{});
|
||||
|
||||
using WeiSrcMergedDimSubLengthsHack = Sequence<1, 1>;
|
||||
using WeiDstMergedDimSubLengthsHack = Sequence<1, 1>;
|
||||
blockwise_wei_copy.Run_hack(p_wei_global,
|
||||
p_wei_block,
|
||||
WeiSrcMergedDimSubLengthsHack{},
|
||||
WeiDstMergedDimSubLengthsHack{});
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -418,6 +418,22 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
mThreadwiseCopy.Run(p_src, p_dst);
|
||||
}
|
||||
|
||||
template <class SrcMergedDimSubLengthsHack, class DstMergedDimSubLengthsHack>
|
||||
__device__ void Run_hack(const TData* p_src,
|
||||
TData* p_dst,
|
||||
SrcMergedDimSubLengthsHack,
|
||||
DstMergedDimSubLengthsHack) const
|
||||
{
|
||||
// hacks to isolate merged dimension from normal dimensions, and calculate their offset
|
||||
// seperately
|
||||
// SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions,
|
||||
// but 1 on normal dimensions;
|
||||
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
mThreadwiseCopy.Run_hack(
|
||||
p_src, p_dst, SrcMergedDimSubLengthsHack{}, DstMergedDimSubLengthsHack{});
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
mThreadwiseCopy.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
|
||||
@@ -140,25 +140,104 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
|
||||
// Copies a SliceLengths-shaped slice from p_src to p_dst, using the member slice origins
// (mSrcSliceOrigin / mDstSliceOrigin) and a packed local array (p_buffer_) whose size is
// buffer_desc.GetElementSpace().
// NOTE(review): this span was taken from a rendered diff without +/- markers, so removed and
// added lines appear interleaved (e.g. both `p_src += ...RepositionOrigin()` and
// `p_src_tmp = p_src + ...RepositionOrigin()` are present). Confirm against the repository
// which statements are actually live before changing any code here.
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
// Packed descriptor with the slice's lengths; used to index the local staging array.
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
TData p_buffer_[buffer_desc.GetElementSpace()];
|
||||
TData* p_buffer = p_buffer_;
|
||||
|
||||
// Disabled variant: offsets are taken directly from the member origins plus data_id.
#if 0
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(mDstSliceOrigin + data_id).GetOffset()] =
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] =
|
||||
p_src[(mSrcSliceOrigin + data_id).GetOffset()];
|
||||
});
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(mDstSliceOrigin + data_id).GetOffset()] =
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)];
|
||||
});
|
||||
// Enabled variant: works on local copies of the origins and shifts the base pointers by
// the value returned from RepositionOrigin().
// NOTE(review): presumably RepositionOrigin() also modifies the local origin copy - verify.
#elif 1
|
||||
auto src_slice_origin = mSrcSliceOrigin;
|
||||
auto dst_slice_origin = mDstSliceOrigin;
|
||||
|
||||
p_src += src_slice_origin.RepositionOrigin();
|
||||
p_dst += dst_slice_origin.RepositionOrigin();
|
||||
const TData* p_src_tmp = p_src + src_slice_origin.RepositionOrigin();
|
||||
TData* p_dst_tmp = p_dst + dst_slice_origin.RepositionOrigin();
|
||||
|
||||
// Pass over every index of the slice (source side).
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(dst_slice_origin + data_id).GetOffset()] =
|
||||
p_src[(src_slice_origin + data_id).GetOffset()];
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] =
|
||||
p_src_tmp[(src_slice_origin + data_id).GetOffset()];
|
||||
});
|
||||
|
||||
// Second pass: write the staged values from p_buffer to the destination.
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst_tmp[(dst_slice_origin + data_id).GetOffset()] =
|
||||
p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)];
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class SrcMergedDimSliceLengthsHack, class DstMergedDimSliceLengthsHack>
|
||||
__device__ void Run_hack(const TData* p_src,
|
||||
TData* p_dst,
|
||||
SrcMergedDimSliceLengthsHack,
|
||||
DstMergedDimSliceLengthsHack) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
TData p_buffer_[buffer_desc.GetElementSpace()];
|
||||
TData* p_buffer = p_buffer_;
|
||||
|
||||
// hacks to isolate merged dimension from normal dimensions, and calculate their offset
|
||||
// seperately
|
||||
// SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions,
|
||||
// but 1 on normal dimensions;
|
||||
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
using SrcNormalDimSliceLengthsHack =
|
||||
decltype((SliceLengths{} + Number<1>{}) - SrcMergedDimSliceLengthsHack{});
|
||||
|
||||
static_ford<SrcMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id_) {
|
||||
constexpr auto merged_dim_data_id = decltype(merged_dim_data_id_){};
|
||||
|
||||
const TData* p_src_tmp = p_src + (mSrcSliceOrigin + merged_dim_data_id).GetOffset();
|
||||
|
||||
static_ford<SrcNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id_) {
|
||||
constexpr auto normal_dim_data_id = decltype(normal_dim_data_id_){};
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
|
||||
|
||||
constexpr index_t src_normal_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
|
||||
|
||||
p_buffer[buffer_offset] = p_src_tmp[src_normal_offset];
|
||||
});
|
||||
});
|
||||
|
||||
// DstMergedDimSliceLengthsHack has entry same as SliceLengths on dst merged dimensions,
|
||||
// but 1 on normal dimensions;
|
||||
// DstNormalDimSliceLengthsHack has entry same as SliceLengths on dst normal dimensions,
|
||||
// but 1 on merged dimensions;
|
||||
using DstNormalDimSliceLengthsHack =
|
||||
decltype((SliceLengths{} + Number<1>{}) - DstMergedDimSliceLengthsHack{});
|
||||
|
||||
static_ford<DstMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id_) {
|
||||
constexpr auto merged_dim_data_id = decltype(merged_dim_data_id_){};
|
||||
|
||||
TData* p_dst_tmp = p_dst + (mDstSliceOrigin + merged_dim_data_id).GetOffset();
|
||||
|
||||
static_ford<DstNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id_) {
|
||||
constexpr auto normal_dim_data_id = decltype(normal_dim_data_id_){};
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
|
||||
|
||||
constexpr index_t dst_normal_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
|
||||
|
||||
p_dst_tmp[dst_normal_offset] = p_buffer[buffer_offset];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
if(positive_direction)
|
||||
|
||||
Reference in New Issue
Block a user