From fdcfae3a62f106720c919681bcb9cc8a4fe83a69 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 6 Aug 2019 17:41:58 -0500 Subject: [PATCH] reimplement threadwise copy --- ...tion_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp | 1 - ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 3 - ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 3 - ..._v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp | 3 - .../ConstantMergedTensorDescriptor.hpp | 12 +- .../blockwise_generic_tensor_slice_copy.hpp | 81 ++++++++---- .../threadwise_generic_tensor_slice_copy.hpp | 119 ++++++++++++++++-- .../include/utility/Sequence.hpp | 47 ++++++- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 2 +- driver/src/driver.cpp | 2 +- 10 files changed, 223 insertions(+), 50 deletions(-) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp index 625e0a1f85..ef608e6061 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp @@ -157,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< BlockSize, - Float, decltype(in_c_h_w_n_global_desc), decltype(in_c_h_w_n_block_desc), NormalTensorCoordinate, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index b83a58bfad..7655fb16e4 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -176,7 +176,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw #else auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< BlockSize, - Float, decltype(in_e_n1_b_n2_global_merged_desc), decltype(in_e_n1_b_n2_block_desc), MergedTensorCoordinate, @@ -219,7 +218,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw #else auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< BlockSize, - Float, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc), NormalTensorCoordinate, @@ -373,7 +371,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw Number<1>{}); #else ThreadwiseGenericTensorSliceCopy_v2< - Float, decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), NormalTensorCoordinate, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 441b8c887e..4c20efb578 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw // this copy operator already has blockwise offset built-in auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2, @@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw // this copy operator already have blockwise offset built-in auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< BlockSize, - Float, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc), NormalTensorCoordinate, @@ -288,7 +286,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw Sequence; auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2< - Float, decltype(out_k0_k1_b_thread_desc), decltype(out_k0_k1_b_global_desc), NormalTensorCoordinate, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp index db6af6ac19..168109da56 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer // this copy operator already has blockwise offset built-in auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2, @@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer // this copy operator already have blockwise offset built-in auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< BlockSize, - Float, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc), NormalTensorCoordinate, @@ -352,7 +350,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer Sequence; auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2< - Float, decltype(out_k0_k1_b_thread_desc), decltype(out_k0_k1_b_global_desc), NormalTensorCoordinate, diff --git a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp index 75b46cecce..01653ffb1f 100644 --- a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp @@ -65,11 +65,21 @@ struct ConstantMergedTensorDescriptor static_assert(!ContainMultipleOriginalDimensions(Number{}), "wrong! stride of a merged dimension is undefined"); - constexpr auto idim_original = std::get(mOriginalDimMergeSeqs).Front(); + constexpr auto idim_original = std::get(mOriginalDimMergeSeqs).Back(); return OriginalTensorDesc::GetStride(Number{}); } + // this is a hack to return the stride of the last original dimension of a merged dimension + // TODO: refactor this once the concept of "dimension" is used + template + __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number) + { + constexpr auto idim_last_original = std::get(mOriginalDimMergeSeqs).Back(); + + return OriginalTensorDesc::GetStride(Number{}); + } + __host__ __device__ static constexpr auto GetLengths() { return Sequence{}; diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index 4f225f5e60..ed3049341e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -13,11 +13,13 @@ namespace ck { -// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor +// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor // memory layout (ordering of dimensions) can be different between src and dst. -// on a merged dimension that constains multiple original dimensions, -// its sub-length need to evenly divide the length of the last original dimension -// so each thread is effectively reading a normal (not merged) tensor +// This functions assume each thread is reading and writing a normal (not merged) tensor, +// to simplify index calculations. To satisfy this assumption, the user need to make sure +// that, on a merged dimension that constains multiple original dimensions, the length of +// the last original dimension need to be evenly dividable by its sub-lengths. Also, the +// repeat-length on the merged dimension need to be 1. template {}([&](auto IDim) { - static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0, - "wrong! cannot evenly divide sliced tensor into sub-tensor"); - static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0, "wrong! cannot evenly divide sliced tensor into cluster"); }); - // on a merged dimension that constains multiple original dimensions, - // its sub-length need to evenly divide the length of the last original dimension, - // so each thread is effectively reading a normal (not merged) tensor - static_for<0, nDim, 1>{}([&](auto IDim) { - constexpr auto sub_length = SubLengths::Get(IDim); + constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims; - constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back(); - static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) % - sub_length == - 0, - "wrong!"); + // additional check for merged dimension + static_for<0, nDim, 1>{}([&](auto IDim_) { + // src + static_if{}([&](auto) { + constexpr auto IDim = decltype(IDim_){}; - constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back(); - static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) % - sub_length == - 0, - "wrong!"); + // on a merged dimension that constains multiple original dimensions, + // the length of the last original dimension need to evenly dividable by its + // sub-length, + // so each thread is effectively reading a normal (not merged) tensor + constexpr auto idim_last_original_src = + SrcDesc::GetContainedOriginalDimensions(IDim).Back(); + static_assert( + SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) % + SubLengths::Get(IDim) == + 0, + "wrong!"); + + // merged dimension should have repeat_lengths = 1 + static_assert(repeat_lengths[IDim] == 1, + "wrong! repeat_lengths shoud be 1 on merged dimension"); + }); + + // dst + static_if{}([&](auto) { + constexpr auto IDim = decltype(IDim_){}; + + // on a merged dimension that constains multiple original dimensions, + // the length of the last original dimension need to evenly dividable by its + // sub-length, + // so each thread is effectively reading a normal (not merged) tensor + constexpr auto idim_last_original_dst = + DstDesc::GetContainedOriginalDimensions(IDim).Back(); + static_assert( + DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) % + SubLengths::Get(IDim) == + 0, + "wrong!"); + + // merged dimension should have repeat_lengths = 1 + static_assert(repeat_lengths[IDim] == 1, + "wrong! repeat_lengths shoud be 1 on merged dimension"); + }); }); // calculate mThreadSrcOffset, mThreadDstOffset @@ -376,7 +403,6 @@ struct BlockwiseGenericTensorSliceCopy_v1 }; template __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const { mThreadwiseLoad.Run(p_src, p_buffer); } + template __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const { mThreadwiseStore.Run(p_buffer, p_dst); } + template __device__ void Run(const TData* p_src, TData* p_dst) const { TData p_buffer[GetRegisterBufferSize()]; @@ -466,16 +495,14 @@ struct BlockwiseGenericTensorSliceCopy_v2 using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{})); using ThreadwiseLoad = - ThreadwiseGenericTensorSliceCopy_v2, SubLengths>; using ThreadwiseStore = - ThreadwiseGenericTensorSliceCopy_v2, DstCoordinate, diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index 10f41178df..b3e659970d 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -106,8 +106,107 @@ __device__ void threadwise_generic_tensor_slice_copy_v1( #endif } -template +struct ThreadwiseGenericTensorSliceCopy_v1 +{ + static constexpr index_t nDim = SliceLengths::GetNumOfDimension(); + + __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array src_slice_origin, + Array dst_slice_origin) + : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin) + { + static_assert(nDim == SrcDesc::GetNumOfDimension() && + nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() && + nDim == SrcDimAccessOrder::GetSize() && + nDim == DstDimAccessOrder::GetSize(), + "wrong! # of dimensions not the same"); + + static_assert(is_valid_sequence_map::{} && + is_valid_sequence_map::{}, + "wrong! map is not valid"); + + static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 && + SliceLengths{DstVectorDim} % DstDataPerAccess == 0, + "wrong! cannot evenly divide"); + + // check vectorized memory access + constexpr auto src_vector_access_dim = Number{}; + constexpr auto dst_vector_access_dim = Number{}; + + static_if{}([&](auto fwd) { + static_assert( + (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1), + "wrong! vectorized access is allowed only if stride == 1"); + }).Else{}([&](auto fwd) { + static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 || + SrcDataPerAccess == 1), + "wrong! vectorized access is allowed only if stride == 1"); + }); + + static_if{}([&](auto fwd) { + static_assert( + (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1), + "wrong! vectorized access is allowed only if stride == 1"); + }).Else{}([&](auto fwd) { + static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 || + DstDataPerAccess == 1), + "wrong! vectorized access is allowed only if stride == 1"); + }); + } + + __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1() + : ThreadwiseGenericTensorSliceCopy_v1(make_zero_array(), + make_zero_array()) + { + } + + __device__ void SetSrcSliceOrigin(Array src_slice_origin) + { + mSrcSliceOrigin = src_slice_origin; + } + + __device__ void SetDstSliceOrigin(Array dst_slice_origin) + { + mDstSliceOrigin = dst_slice_origin; + } + + template + __device__ void Run(const TData* p_src, TData* p_dst) const + { + constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{}); + + TData p_buffer[buffer_desc.GetElementSpace()]; + + // copy data from src into buffer + constexpr auto src_vector_access_dim = Number{}; + + constexpr auto src_access_lengths = SliceLengths::Modify( + src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess); + + constexpr auto src_access_lengths_in_src_access_order = + src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{}); + + static_ford{}([&](auto src_access_id) {}); + } + + private: + Array mSrcSliceOrigin; + Array mDstSliceOrigin; +}; +#endif + +template ()), - mDstSliceOrigin(make_zero_array()) - { - } - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin, DstCoordinate dst_slice_origin) : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin) { } + __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2() + : ThreadwiseGenericTensorSliceCopy_v2(make_zero_array(), + make_zero_array()) + { + } + __device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin) { mSrcSliceOrigin = src_slice_origin; @@ -148,6 +247,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2 } }; + template __device__ void Run(const TData* p_src, TData* p_dst) const { constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{}); @@ -216,6 +316,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2 }); } + // T can be Sequence or Array template __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant) { @@ -232,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2 }).Else([&](auto) { mDstSliceOrigin -= step_sizes; }); } - // private: + private: SrcCoordinate mSrcSliceOrigin; DstCoordinate mDstSliceOrigin; }; diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp index a597fa6e5a..44cfd669db 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/Sequence.hpp @@ -6,9 +6,12 @@ namespace ck { -template +template struct is_valid_sequence_map; +template +struct sequence_map_inverse; + template struct Sequence { @@ -34,6 +37,8 @@ struct Sequence return Number{})>{}; } + __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); } + template __host__ __device__ constexpr auto operator[](Number) const { @@ -54,6 +59,18 @@ struct Sequence return Sequence{})...>{}; } + // MapOld2New is Sequence<...> + template + __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) + { + static_assert(MapOld2New::GetSize() == GetSize(), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); + + return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); + } + __host__ __device__ static constexpr auto Reverse(); __host__ __device__ static constexpr auto Front() @@ -253,6 +270,7 @@ struct sequence_reverse> template struct is_valid_sequence_map { + // not implemented yet, always return true static constexpr integral_constant value = integral_constant{}; // TODO: add proper check for is_valid, something like: @@ -261,6 +279,33 @@ struct is_valid_sequence_map // typename sequence_sort::SortedSeqType>{}; }; +template +struct sequence_map_inverse_impl +{ + private: + static constexpr auto new_y2x = WorkingY2X::Modify(X2Y{}[XBegin], XBegin); + + public: + using type = + typename sequence_map_inverse_impl::type; +}; + +template +struct sequence_map_inverse_impl +{ + using type = WorkingY2X; +}; + +template +struct sequence_map_inverse +{ + using type = + typename sequence_map_inverse_impl::type, + 0, + X2Y::GetSize()>::type; +}; + template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 25c0bf2602..e1f950739a 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); constexpr auto gridwise_conv = -#if 1 +#if 0 GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw #else GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 4a75628952..540f81186c 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -379,7 +379,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 0 +#elif 1 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,