From c01af89928da9082251860182baf43fb2153209c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 3 Aug 2019 00:02:24 -0500 Subject: [PATCH] added new tensor copy operator --- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 28 +--- ..._v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp | 85 +++++++--- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 28 +--- .../blockwise_generic_tensor_slice_copy.hpp | 158 +++++++++++------- .../threadwise_generic_tensor_slice_copy.hpp | 56 ++----- .../include/utility/Sequence.hpp | 51 ++++-- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 2 +- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 1 + driver/src/driver.cpp | 2 +- 9 files changed, 215 insertions(+), 196 deletions(-) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index d25469ba21..b83a58bfad 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -295,27 +295,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw // do work for(index_t e = 0; e < E; e += EPerBlock) { -#if 0 // debug blockwise_in_copy.Run(p_in_global, p_in_block); blockwise_wei_copy.Run(p_wei_global, p_wei_block); -#else - using InSrcMergedDimSubLengthsHack = Sequence; - using InDstMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>; - blockwise_in_copy.Run_hack(p_in_global, - p_in_block, - InSrcMergedDimSubLengthsHack{}, - InDstMergedDimSubLengthsHack{}); - - using WeiSrcMergedDimSubLengthsHack = Sequence<1, 1>; - using WeiDstMergedDimSubLengthsHack = Sequence<1, 1>; - blockwise_wei_copy.Run_hack(p_wei_global, - p_wei_block, - WeiSrcMergedDimSubLengthsHack{}, - WeiDstMergedDimSubLengthsHack{}); -#endif __syncthreads(); @@ -391,10 +372,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); #else - - using OutSrcMergedDimSliceLengthsHack = Sequence<1, 1, 1, 1, 1, 1, 1, 1>; - using OutDstMergedDimSliceLengthsHack = Sequence<1, 1, 1, 1, 1, 1, 1, 1>; - ThreadwiseGenericTensorSliceCopy_v2< Float, decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), @@ -403,10 +380,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw MergedTensorCoordinate, decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths())>( {0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0}) - .Run_hack(p_out_thread, - p_out_thread_on_global, - OutSrcMergedDimSliceLengthsHack{}, - OutDstMergedDimSliceLengthsHack{}); + .Run(p_out_thread, p_out_thread_on_global); #endif } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index 82b097c5e6..f1386b1d92 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -155,6 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, "GemmDataPerReadB alignment requirement is not satisfied"); +#if 1 // debug // input blockwise copy // slice a merged tensor, reorder and copy to a normal tensor // this copy operator already has blockwise offset built-in @@ -172,6 +173,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2>( {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); +#else + auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< + BlockSize, + Float, + decltype(in_e_n1_b_n2_global_merged_desc), + decltype(in_e_n1_b_n2_block_desc), + MergedTensorCoordinate, + NormalTensorCoordinate, + decltype(in_e_n1_b_n2_block_desc.GetLengths()), + InBlockCopySubLengths_E_N1_B_N2, + InBlockCopyClusterLengths_E_N1_B_N2, + InBlockCopyThreadClusterArrangeOrder>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); +#endif // weight tensor // tensor descriptor in device memory, src of blockwise copy @@ -184,6 +198,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer Sequence{}, Number{}); +#if 1 // debug // operator for blockwise copy of weight into LDS // slice a tensor, and copy it into another tensor // this copy operator already have blockwise offset built-in @@ -201,6 +216,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>( {0, k_block_data_on_global}, {0, 0}); +#else + auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< + BlockSize, + Float, + decltype(wei_e_k_global_desc), + decltype(wei_e_k_block_desc), + NormalTensorCoordinate, + NormalTensorCoordinate, + decltype(wei_e_k_block_desc.GetLengths()), + WeiBlockCopySubLengths_E_K, + WeiBlockCopyClusterLengths_E_K, + WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0}); +#endif // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -291,54 +319,61 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer Float* p_wei_block_next = even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; + Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; +#if 1 blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); +#else + blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true); + blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); +#endif __syncthreads(); // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, - p_wei_register_clipboard); + blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); + blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, + p_wei_register_buffer); // LDS double buffer: GEMM on current data blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_next); + blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); } } // LDS double buffer: tail { - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; + Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - // even iteration +// even iteration +#if 1 blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); +#else + blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true); + blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); +#endif __syncthreads(); // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, - p_wei_register_clipboard); + blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); + blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); // LDS double buffer: GEMM on current data blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_double + wei_block_space); + blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, + p_in_block_double + in_block_space); + blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, + p_wei_block_double + wei_block_space); // odd iteration __syncthreads(); @@ -396,6 +431,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( k_thread_data_on_global, 0, b_thread_data_on_global, 0); +#if 1 // debug threadwise_generic_tensor_slice_copy_v1( out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, p_out_thread, @@ -406,6 +442,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); +#else + ThreadwiseGenericTensorSliceCopy_v2< + Float, + decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), + decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), + NormalTensorCoordinate, + MergedTensorCoordinate, + decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths())>( + {0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0}) + .Run(p_out_thread, p_out_thread_on_global); +#endif } } }; diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 4aace546c4..b882f8b20c 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -11,7 +11,7 @@ namespace ck { -// B = merge(N, H, W) +// B = merge(N, Ho, Wo) template ; - blockwise_in_copy.Run_hack(p_in_global, - p_in_block, - InSrcMergedDimSubLengthsHack{}, - InDstMergedDimSubLengthsHack{}); - - using WeiSrcMergedDimSubLengthsHack = Sequence<1, 1>; - using WeiDstMergedDimSubLengthsHack = Sequence<1, 1>; - blockwise_wei_copy.Run_hack(p_wei_global, - p_wei_block, - WeiSrcMergedDimSubLengthsHack{}, - WeiDstMergedDimSubLengthsHack{}); -#endif __syncthreads(); @@ -318,17 +302,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) { -#if 0 threadwise_out_copy.Run(p_out_thread, p_out_global); -#else - using OutSrcMergedDimSubLengthsHack = Sequence<1, 1, 1>; - using OutDstMergedDimSubLengthsHack = - Sequence<1, 1, OutThreadCopySliceLengths{}[2]>; - threadwise_out_copy.Run_hack(p_out_thread, - p_out_global, - OutSrcMergedDimSubLengthsHack{}, - OutDstMergedDimSubLengthsHack{}); -#endif threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true); threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true); diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index 469de6b09b..fa2466be91 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -24,7 +24,7 @@ template {}([&](auto IDim) { static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0, @@ -160,9 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v1 mThreadDstPartialOffsets, math::plus{}, static_cast(0)); } - __device__ static constexpr index_t GetRegisterClipboardSize() + __device__ static constexpr index_t GetRegisterBufferSize() { - constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{}); + constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{}); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths); @@ -170,14 +171,15 @@ struct BlockwiseGenericTensorSliceCopy_v1 return thread_tensor_desc.GetElementSpace(); } - __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, - Float* __restrict__ p_clipboard) const + __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src, + Float* __restrict__ p_Buffer) const { constexpr auto thread_sub_tensor_lengths = SubLengths{}; - constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{}; + constexpr auto data_per_cluster_per_dims = + thread_sub_tensor_lengths * ThreadClusterLengths{}; - constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{}); + constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{}); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); @@ -187,25 +189,24 @@ struct BlockwiseGenericTensorSliceCopy_v1 constexpr auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; - constexpr auto clipboard_data_multi_id_begin = - repeat_multi_id * thread_sub_tensor_lengths; + constexpr auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; constexpr index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin); - constexpr index_t clipboard_offset = - thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); + constexpr index_t Buffer_offset = + thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin); #else ford{}([&](auto repeat_multi_id) { const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; - const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; + const auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin); - const index_t clipboard_offset = - thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); + const index_t Buffer_offset = + thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin); #endif // By position the origin of the per-thread window at the point, where multi-index @@ -219,7 +220,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 p_src + src_offset + mThreadSrcOffset, make_zero_array(), thread_tensor_desc, - p_clipboard + clipboard_offset, + p_Buffer + Buffer_offset, make_zero_array(), thread_sub_tensor_lengths, SrcAccessOrder{}, @@ -227,38 +228,38 @@ struct BlockwiseGenericTensorSliceCopy_v1 }); } - __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, - Float* __restrict__ p_dst) const + __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_Buffer, + Float* __restrict__ p_dst) const { constexpr auto thread_sub_tensor_lengths = SubLengths{}; - constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{}; + constexpr auto data_per_cluster_per_dims = + thread_sub_tensor_lengths * ThreadClusterLengths{}; - constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{}); + constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{}); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 static_ford{}([&](auto repeat_multi_id) { - constexpr auto clipboard_data_multi_id_begin = - repeat_multi_id * thread_sub_tensor_lengths; + constexpr auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; - constexpr index_t clipboard_offset = - thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); + constexpr index_t Buffer_offset = + thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin); constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin); #else ford{}([&](auto repeat_multi_id) { - const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; + const auto Buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; - const index_t clipboard_offset = - thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); + const index_t Buffer_offset = + thread_tensor_desc.GetOffsetFromMultiIndex(Buffer_data_multi_id_begin); const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin); #endif @@ -271,7 +272,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 // If in the future, you want to enable SubLengths > 1 at the merged dimension, // special care in implementation is needed threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc, - p_clipboard + clipboard_offset, + p_Buffer + Buffer_offset, make_zero_array(), DstDesc{}, p_dst + dst_offset + mThreadDstOffset, @@ -284,10 +285,10 @@ struct BlockwiseGenericTensorSliceCopy_v1 __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const { - Float p_clipboard[GetRegisterClipboardSize()]; + Float p_Buffer[GetRegisterBufferSize()]; - RunLoadRegisterClipboard(p_src, p_clipboard); - RunStoreRegisterClipboard(p_clipboard, p_dst); + RunLoadRegisterBuffer(p_src, p_Buffer); + RunStoreRegisterBuffer(p_Buffer, p_dst); } // When moving the slicing windows along a merged dimension, if the strides of the @@ -382,24 +383,30 @@ template struct BlockwiseGenericTensorSliceCopy_v2 { - using ThreadwiseCopy = ThreadwiseGenericTensorSliceCopy_v2; - static constexpr index_t nDim = SrcDesc::GetNumOfDimension(); __device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin, DstCoordinate dst_block_slice_origin) { + static_assert(nDim == SrcDesc::GetNumOfDimension() && + nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() && + nDim == SubLengths::GetSize() && + nDim == ThreadClusterLengths::GetSize() && + nDim == ThreadClusterArrangeOrder::GetSize(), + "wrong! nDim not consistent"); + + static_assert(is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed( - DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{})); + ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{})); + + static_assert(BlockSize == thread_cluster_desc.GetElementSize(), + "wrong! BlockSize not consistent with ThreadClusterLengths"); const auto thread_cluster_multi_id = thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id()); @@ -409,43 +416,66 @@ struct BlockwiseGenericTensorSliceCopy_v2 const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{}; - mThreadwiseCopy.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin); - mThreadwiseCopy.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin); + mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin); + mThreadwiseLoad.SetDstSliceOrigin(make_zero_array()); + + mThreadwiseStore.SetSrcSliceOrigin(make_zero_array()); + mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin); + } + + __device__ static constexpr index_t GetRegisterBufferSize() + { + return RegisterBufferDesc::GetElementSpace(); + } + + __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const + { + mThreadwiseLoad.Run(p_src, p_buffer); + } + + __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const + { + mThreadwiseStore.Run(p_buffer, p_dst); } __device__ void Run(const TData* p_src, TData* p_dst) const { - mThreadwiseCopy.Run(p_src, p_dst); - } + TData p_buffer[GetRegisterBufferSize()]; - template - __device__ void Run_hack(const TData* p_src, - TData* p_dst, - SrcMergedDimSubLengthsHack, - DstMergedDimSubLengthsHack) const - { - // hacks to isolate merged dimension from normal dimensions, and calculate their offset - // seperately - // SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions, - // but 1 on normal dimensions; - // SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions, - // but 1 on merged dimensions; - mThreadwiseCopy.Run_hack( - p_src, p_dst, SrcMergedDimSubLengthsHack{}, DstMergedDimSubLengthsHack{}); + mThreadwiseLoad.Run(p_src, p_buffer); + mThreadwiseStore.Run(p_buffer, p_dst); } __device__ void MoveSrcSlicingWindow(Array step_sizes, bool positive_direction) { - mThreadwiseCopy.MoveSrcSlicingWindow(step_sizes, positive_direction); + mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction); } __device__ void MoveDstSlicingWindow(Array step_sizes, bool positive_direction) { - mThreadwiseCopy.MoveDstSlicingWindow(step_sizes, positive_direction); + mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction); } - // private: - ThreadwiseCopy mThreadwiseCopy; + private: + using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{})); + + using ThreadwiseLoad = + ThreadwiseGenericTensorSliceCopy_v2, + SubLengths>; + + using ThreadwiseStore = + ThreadwiseGenericTensorSliceCopy_v2, + DstCoordinate, + SubLengths>; + ThreadwiseLoad mThreadwiseLoad; + ThreadwiseStore mThreadwiseStore; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index 0eba4b2807..aa75a7fe6c 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -138,47 +138,17 @@ struct ThreadwiseGenericTensorSliceCopy_v2 mDstSliceOrigin = dst_slice_origin; } - __device__ void Run(const TData* p_src, TData* p_dst) const + template + struct IsolateMergedDimSliceLengthsHack { - constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{}); + template + __device__ constexpr index_t operator()(IDim idim) const + { + return TDesc::ContainMultipleOriginalDimensions(idim) ? Seq{}[idim] : 1; + } + }; - TData p_buffer_[buffer_desc.GetElementSpace()]; - TData* p_buffer = p_buffer_; - -#if 0 - static_ford{}([&](auto data_id) { - p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] = - p_src[(mSrcSliceOrigin + data_id).GetOffset()]; - }); - - static_ford{}([&](auto data_id) { - p_dst[(mDstSliceOrigin + data_id).GetOffset()] = - p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)]; - }); -#elif 1 - auto src_slice_origin = mSrcSliceOrigin; - auto dst_slice_origin = mDstSliceOrigin; - - const TData* p_src_tmp = p_src + src_slice_origin.RepositionOrigin(); - TData* p_dst_tmp = p_dst + dst_slice_origin.RepositionOrigin(); - - static_ford{}([&](auto data_id) { - p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)] = - p_src_tmp[(src_slice_origin + data_id).GetOffset()]; - }); - - static_ford{}([&](auto data_id) { - p_dst_tmp[(dst_slice_origin + data_id).GetOffset()] = - p_buffer[buffer_desc.GetOffsetFromMultiIndex(data_id)]; - }); -#endif - } - - template - __device__ void Run_hack(const TData* p_src, - TData* p_dst, - SrcMergedDimSliceLengthsHack, - DstMergedDimSliceLengthsHack) const + __device__ void Run(const TData* p_src, TData* p_dst) const { constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{}); @@ -191,6 +161,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2 // but 1 on normal dimensions; // SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions, // but 1 on merged dimensions; + using SrcMergedDimSliceLengthsHack = + typename sequence_gen>::type; + using SrcNormalDimSliceLengthsHack = decltype((SliceLengths{} + Number<1>{}) - SrcMergedDimSliceLengthsHack{}); @@ -216,6 +190,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2 // but 1 on normal dimensions; // DstNormalDimSliceLengthsHack has entry same as SliceLengths on dst normal dimensions, // but 1 on merged dimensions; + using DstMergedDimSliceLengthsHack = + typename sequence_gen>::type; + using DstNormalDimSliceLengthsHack = decltype((SliceLengths{} + Number<1>{}) - DstMergedDimSliceLengthsHack{}); diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp index 4e410964c9..a597fa6e5a 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/Sequence.hpp @@ -128,48 +128,63 @@ struct sequence_merge, Sequence> using type = Sequence; }; -// arithmetic sqeuence -template -struct arithmetic_sequence_gen_impl +// generate sequence +template +struct sequence_gen_impl { - static constexpr index_t NSizeLeft = NSize / 2; + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; - using type = typename sequence_merge< - typename arithmetic_sequence_gen_impl::type, - typename arithmetic_sequence_gen_impl::type>::type; + using type = + typename sequence_merge::type, + typename sequence_gen_impl::type>::type; }; -template -struct arithmetic_sequence_gen_impl +template +struct sequence_gen_impl { - using type = Sequence; + static constexpr index_t Is = F{}(Number{}); + using type = Sequence; }; -template -struct arithmetic_sequence_gen_impl +template +struct sequence_gen_impl { using type = Sequence<>; }; +template +struct sequence_gen +{ + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence template struct arithmetic_sequence_gen { - using type = typename arithmetic_sequence_gen_impl::type; + struct F + { + __host__ __device__ constexpr index_t operator()(index_t i) const + { + return i * Increment + IBegin; + } + }; + + using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; }; // uniform sequence template struct uniform_sequence_gen { - struct return_constant + struct F { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; - using type = decltype( - typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{})); + using type = typename sequence_gen::type; }; // reverse inclusive scan (with init) sequence diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index f79b17ae6d..67395b978d 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -139,7 +139,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 1 +#if 0 GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw #else GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index f6e9560385..529e51378c 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -4,6 +4,7 @@ #include "tensor.hpp" #include "gridwise_convolution_kernel_wrapper.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" using namespace ck; diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 8749fc1ae9..c9488b211a 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -532,7 +532,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 0 +#elif 1 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,