diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
index dc236f0473..316f1e46b5 100644
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -170,6 +170,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                                               InBlockCopyThreadClusterArrangeOrder,
                                               InBlockCopySrcAccessOrder,
                                               InBlockCopyDstAccessOrder,
+                                              2,
+                                              3,
                                               InBlockCopySrcDataPerRead_B,
                                               InBlockCopyDstDataPerWrite_N2>(
                 {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
@@ -213,6 +215,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                                               WeiBlockCopyThreadClusterArrangeOrder,
                                               WeiBlockCopySrcAccessOrder,
                                               WeiBlockCopyDstAccessOrder,
+                                              0,
+                                              1,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K>(
                 {0, k_block_data_on_global}, {0, 0});
@@ -434,7 +438,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                 k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-#if 1
+#if 0
         threadwise_generic_tensor_slice_copy_v1(
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
             p_out_thread,
@@ -445,9 +449,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
             arithmetic_sequence_gen<0, 8, 1>::type{},
             Number<1>{});
-#else
+#elif 1
+        ThreadwiseGenericTensorSliceCopy_v1<
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
+            arithmetic_sequence_gen<0, 8, 1>::type,
+            arithmetic_sequence_gen<0, 8, 1>::type,
+            0,
+            0,
+            1,
+            1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
+            .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
         ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
             NormalTensorCoordinate,
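A note on the bare `2, 3` and `0, 1` arguments added above: they bind to the new SrcVectorAccessDim / DstVectorAccessDim template parameters introduced in blockwise_generic_tensor_slice_copy.hpp below. A minimal sketch with hypothetical named constants (the names are mine, not part of the patch) of which dimensions they select:

    // in-block copy operates on [E, N1, B, N2]: B is dim 2, N2 is dim 3
    constexpr int InBlockCopySrcVectorAccessDim  = 2; // vector loads along B
    constexpr int InBlockCopyDstVectorAccessDim  = 3; // vector stores along N2
    // weight block copy operates on [E, K]: E is dim 0, K is dim 1
    constexpr int WeiBlockCopySrcVectorAccessDim = 0; // vector loads along E
    constexpr int WeiBlockCopyDstVectorAccessDim = 1; // vector stores along K

These line up with the per-dimension parameters they sit next to: InBlockCopySrcDataPerRead_B vectorizes along B, InBlockCopyDstDataPerWrite_N2 along N2, and likewise WeiBlockCopySrcDataPerRead_E / WeiBlockCopyDstDataPerWrite_K along E and K.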
diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
index 0c1e5af052..4a1acd102b 100644
--- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
@@ -19,7 +19,8 @@ namespace ck {
 // to simplify index calculations. To satisfy this assumption, the user need to make sure
 // that, on a merged dimension that constains multiple original dimensions, the length of
 // the last original dimension need to be evenly dividable by its sub-lengths. Also, the
-// repeat-length on the merged dimension need to be 1.
+// repeat-length on the merged dimension need to be 1. These sanity checks are performed
+// in the constructor of BlockwiseGenericTensorSliceCopy_v1.
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
           class SliceLengths,
           class SubLengths,
           class ThreadClusterLengths,
           class ThreadClusterArrangeOrder,
-          class SrcAccessOrder,
-          class DstAccessOrder,
-          index_t SrcDataPerRead,
-          index_t DstDataPerWrite>
+          class SrcDimAccessOrder,
+          class DstDimAccessOrder,
+          index_t SrcVectorAccessDim,
+          index_t DstVectorAccessDim,
+          index_t SrcDataPerAccess,
+          index_t DstDataPerAccess>
 struct BlockwiseGenericTensorSliceCopy_v1
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
@@ -60,23 +63,22 @@
     Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
     Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;

-    __device__
-    BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
-                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
+    __device__ BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_id_begin,
+                                                  Array<index_t, nDim> dst_block_data_id_begin)
     {
         // check NDim consistency
-        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
-                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
-                          nDim == SubLengths::GetSize() &&
-                          nDim == ThreadClusterLengths::GetSize() &&
-                          nDim == ThreadClusterArrangeOrder::GetSize() &&
-                          nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
-                      "wrong");
+        static_assert(
+            nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
+                nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
+                nDim == ThreadClusterLengths::GetSize() &&
+                nDim == ThreadClusterArrangeOrder::GetSize() &&
+                nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
+            "wrong");

         // check thread arrange order and read/write access order are valid
         static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
-                          is_valid_sequence_map<SrcAccessOrder>::value &&
-                          is_valid_sequence_map<DstAccessOrder>::value,
+                          is_valid_sequence_map<SrcDimAccessOrder>::value &&
+                          is_valid_sequence_map<DstDimAccessOrder>::value,
                       "wrong!");

         // thread cluster
@@ -142,20 +144,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
         });

         // calculate mThreadSrcOffset, mThreadDstOffset
-        const auto thread_cluster_multi_id =
+        const auto thread_cluster_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

-        const auto data_cluster_multi_id =
-            reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
+        const auto data_cluster_id =
+            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

-        const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
+        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

         // original multi-id
         mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
-            src_block_data_multi_id_begin + thread_data_multi_id_begin);
+            src_block_data_id_begin + thread_data_id_begin);

         mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
-            dst_block_data_multi_id_begin + thread_data_multi_id_begin);
+            dst_block_data_id_begin + thread_data_id_begin);

         // partial offset on each dimension
         static_for<0, nDim, 1>{}([&](auto IDim) {
@@ -188,14 +190,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
                                                  mThreadDstPartialOffsets,
                                                  math::plus<index_t>{},
                                                  static_cast<index_t>(0));
     }

-    __device__ static constexpr index_t GetRegisterBufferSize()
+    __device__ static constexpr auto GetRegisterBufferDescriptor()
     {
         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
+        return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
+    }

-        return thread_tensor_desc.GetElementSpace();
+    __device__ static constexpr index_t GetRegisterBufferSize()
+    {
+        return GetRegisterBufferDescriptor().GetElementSpace();
     }

     __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
@@ -208,50 +212,62 @@
         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
+        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            constexpr auto src_thread_data_multi_id_begin =
-                repeat_multi_id * data_per_cluster_per_dims;
+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

-            constexpr auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

             constexpr index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

             constexpr index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

-            const auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            const index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+            const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

             const index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
 #endif

-        // By position the origin of the per-thread window at the point, where multi-index
-        // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-        // is assuming each thread is copy a noraml (not merged) tensor.
-        // User need to guarantee this is true.
-        // By setting SubLengths = 1 at the merged dimension, this is always true;
-        // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-        // special care in implementation is needed
+// By positioning the origin of the per-thread window at the point where the multi-index
+// of the SrcDesc (which might be a merged tensor) is all-zero, this threadwise slice copy
+// assumes each thread copies a normal (not merged) tensor.
+// To satisfy this assumption, the user needs to make sure that, on a merged dimension
+// that contains multiple original dimensions, the length of the last original
+// dimension is evenly divisible by its sub-lengths. Also, the repeat-length on
+// the merged dimension needs to be 1. These sanity checks are performed in the
+// constructor of BlockwiseGenericTensorSliceCopy_v1.
+#if 0 // debug
             threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                     p_src + src_offset + mThreadSrcOffset,
                                                     make_zero_array<index_t, nDim>(),
-                                                    thread_tensor_desc,
+                                                    thread_buffer_desc,
                                                     p_buffer + buffer_offset,
                                                     make_zero_array<index_t, nDim>(),
                                                     thread_sub_tensor_lengths,
-                                                    SrcAccessOrder{},
-                                                    Number<SrcDataPerRead>{});
+                                                    SrcDimAccessOrder{},
+                                                    Number<SrcDataPerAccess>{});
+#else
+            ThreadwiseGenericTensorSliceCopy_v1<SrcDesc,
+                                                decltype(thread_buffer_desc),
+                                                SubLengths,
+                                                SrcDimAccessOrder,
+                                                typename arithmetic_sequence_gen<0, nDim, 1>::type,
+                                                SrcVectorAccessDim,
+                                                0,
+                                                SrcDataPerAccess,
+                                                1>(make_zero_array<index_t, nDim>(),
+                                                   make_zero_array<index_t, nDim>())
+                .Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
+#endif
         });
     }
@@ -265,48 +281,60 @@
         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
+        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            constexpr auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+            constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

             constexpr index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

-            constexpr index_t dst_offset =
-                DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            const auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+            const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

             const index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

-            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
 #endif

-        // By position the origin of the per-thread window at the point, where multi-index
-        // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-        // is assuming each thread is copy a noraml (not merged) tensor.
-        // User need to guarantee this is true.
-        // By setting SubLengths = 1 at the merged dimension, this is always true;
-        // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-        // special care in implementation is needed
-        threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
+// By positioning the origin of the per-thread window at the point where the multi-index
+// of the SrcDesc (which might be a merged tensor) is all-zero, this threadwise slice copy
+// assumes each thread copies a normal (not merged) tensor.
+// The user needs to guarantee this is true.
+// By setting SubLengths = 1 on the merged dimension, this is always true;
+// if, in the future, you want to enable SubLengths > 1 on the merged dimension,
+// special care in the implementation is needed.
+#if 0 // debug
+        threadwise_generic_tensor_slice_copy_v1(thread_buffer_desc,
                                                 p_buffer + buffer_offset,
                                                 make_zero_array<index_t, nDim>(),
                                                 DstDesc{},
                                                 p_dst + dst_offset + mThreadDstOffset,
                                                 make_zero_array<index_t, nDim>(),
                                                 thread_sub_tensor_lengths,
-                                                DstAccessOrder{},
-                                                Number<DstDataPerWrite>{});
+                                                DstDimAccessOrder{},
+                                                Number<DstDataPerAccess>{});
+#else
+            ThreadwiseGenericTensorSliceCopy_v1<decltype(thread_buffer_desc),
+                                                DstDesc,
+                                                SubLengths,
+                                                typename arithmetic_sequence_gen<0, nDim, 1>::type,
+                                                DstDimAccessOrder,
+                                                0,
+                                                DstVectorAccessDim,
+                                                1,
+                                                DstDataPerAccess>(make_zero_array<index_t, nDim>(),
+                                                                  make_zero_array<index_t, nDim>())
+                .Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
+#endif
         });
     }
@@ -346,26 +374,25 @@ struct BlockwiseGenericTensorSliceCopy_v1
             SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

         // calculate new partial original multi-id
-        auto old_src_partial_original_multi_id =
+        auto old_src_partial_original_id =
             extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);

-        auto new_src_partial_original_multi_id =
+        auto new_src_partial_original_id =
             src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
-                old_src_partial_original_multi_id, StepSize, direction);
+                old_src_partial_original_id, StepSize, direction);

         // update "mThreadSrcOriginalMultiId"
         static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
             constexpr auto IDimOriginal = src_partial_original_dims[I];

-            mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
+            mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
         });

         // calculate new partial offset on this merged dimension
         const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];

         const index_t new_src_partial_offset =
-            src_partial_original_desc.GetOffsetFromMultiIndex(
-                new_src_partial_original_multi_id);
+            src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);

         // update "mThreadSrcPartialOffsets"
         mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
@@ -434,19 +461,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
         static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
                       "wrong! BlockSize not consistent with ThreadClusterLengths");

-        const auto thread_cluster_multi_id =
+        const auto thread_cluster_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

-        const auto data_cluster_multi_id =
-            reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
+        const auto data_cluster_id =
+            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

-        const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
+        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

-        mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin);
+        mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
         mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());

         mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
-        mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin);
+        mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
     }

     __device__ static constexpr index_t GetRegisterBufferSize()
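For orientation before the threadwise changes below: the tiling invariant BlockwiseGenericTensorSliceCopy_v1 relies on is, per dimension d, SliceLengths[d] == SubLengths[d] * ThreadClusterLengths[d] * repeat_lengths[d]. A self-contained sketch in plain C++ (the concrete numbers are taken from the v4r1 in-block copy configuration added in the driver later in this patch, where the slice is [EPerBlock, N1, BPerBlock, N2] = [8, 2, 16, 4]):

    #include <array>
    #include <cstddef>

    constexpr std::array<std::size_t, 4> slice_lengths{8, 2, 16, 4};   // [E, N1, B, N2]
    constexpr std::array<std::size_t, 4> sub_lengths{1, 1, 1, 4};      // per-thread tile
    constexpr std::array<std::size_t, 4> cluster_lengths{8, 2, 16, 1}; // thread grid

    constexpr std::size_t repeat_length(std::size_t d)
    {
        return slice_lengths[d] / (sub_lengths[d] * cluster_lengths[d]);
    }

    // 8 * 2 * 16 * 1 threads == BlockSize 256, and every repeat-length is 1,
    // which is exactly what the merged-dimension rule in the comments requires.
    static_assert(8 * 2 * 16 * 1 == 256, "thread cluster must fill the block");
    static_assert(repeat_length(0) == 1 && repeat_length(1) == 1 &&
                      repeat_length(2) == 1 && repeat_length(3) == 1,
                  "repeat-length must be 1 on every dimension of this config");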
diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
index d7812f8680..ce620bcf88 100644
--- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
@@ -106,7 +106,7 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
 #endif
 }

-#if 0
+#if 1
 template <class SrcDesc,
           class DstDesc,
           class SliceLengths,
           class SrcDimAccessOrder,
           class DstDimAccessOrder,
           index_t SrcVectorAccessDim,
           index_t DstVectorAccessDim,
           index_t SrcDataPerAccess,
           index_t DstDataPerAccess>
 struct ThreadwiseGenericTensorSliceCopy_v1
 {
-    static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
+    static constexpr index_t nDim = SliceLengths::GetSize();

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
                                                              Array<index_t, nDim> dst_slice_origin)
@@ -130,39 +130,43 @@ struct ThreadwiseGenericTensorSliceCopy_v1
                           nDim == DstDimAccessOrder::GetSize(),
                       "wrong! # of dimensions not the same");

-        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::{} &&
-                          is_valid_sequence_map<DstDimAccessOrder>::{},
+        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
+                          is_valid_sequence_map<DstDimAccessOrder>::value,
                       "wrong! map is not valid");

-        static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 &&
-                          SliceLengths{DstVectorDim} % DstDataPerAccess == 0,
+        static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
+                          SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
                       "wrong! cannot evenly divide");

         // check vectorized memory access
-        constexpr auto src_vector_access_dim = Number<SrcVectorDim>{};
-        constexpr auto dst_vector_access_dim = Number<DstVectorDim>{};
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};

-        static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}([&](auto fwd) {
-            static_assert(
-                (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
-                "wrong! vectorized access is allowed only if stride == 1");
-        }).Else{}([&](auto fwd) {
-            static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
-                           SrcDataPerAccess == 1),
-                          "wrong! vectorized access is allowed only if stride == 1");
-        });
+        static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
+            [&](auto fwd) {
+                static_assert(
+                    (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            })
+            .Else([&](auto fwd) {
+                static_assert(
+                    (fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
+                     SrcDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            });

-        static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}([&](auto fwd) {
-            static_assert(
-                (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
-                "wrong! vectorized access is allowed only if stride == 1");
-        }).Else{}([&](auto fwd) {
-            static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
-                           DstDataPerAccess == 1),
-                          "wrong! vectorized access is allowed only if stride == 1");
-        });
+        static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
+            [&](auto fwd) {
+                static_assert(
+                    (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            })
+            .Else([&](auto fwd) {
+                static_assert(
+                    (fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
+                     DstDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            });
     }

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
@@ -186,23 +190,87 @@
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});

-        TData p_buffer[buffer_desc.GetElementSpace()];
+        TData p_buffer_[buffer_desc.GetElementSpace()];
+        TData* p_buffer = p_buffer_;

         // copy data from src into buffer
-        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+        {
+            using vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;

-        constexpr auto src_access_lengths = SliceLengths::Modify(
-            src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
+            constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+            constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};

-        constexpr auto src_access_lengths_in_src_access_order =
-            src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
+            constexpr auto src_access_lengths = SliceLengths::Modify(
+                src_vector_access_dim,
+                SliceLengths::Get(src_vector_access_dim) / src_data_per_access);

-        static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
+            static_ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
+                constexpr auto src_data_id = src_access_id.Modify(
+                    src_vector_access_dim,
+                    src_access_id[src_vector_access_dim] * src_data_per_access);
+
+                const index_t src_offset =
+                    SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_id);
+
+                // load vector from src
+                const vector_t vector_data = *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+
+                // unpack vector into buffer
+                static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
+                    constexpr auto scalar_id =
+                        typename uniform_sequence_gen<nDim, 0>::type{}.Modify(src_vector_access_dim,
+                                                                              i);
+
+                    constexpr index_t buffer_offset =
+                        buffer_desc.GetOffsetFromMultiIndex(src_data_id + scalar_id);
+
+                    p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
+                });
+            });
+        }
+
+        // copy data from buffer to dst
+        {
+            using vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+
+            constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
+            constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};
+
+            constexpr auto dst_access_lengths = SliceLengths::Modify(
+                dst_vector_access_dim,
+                SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
+
+            static_ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
+                constexpr auto dst_data_id = dst_access_id.Modify(
+                    dst_vector_access_dim,
+                    dst_access_id[dst_vector_access_dim] * dst_data_per_access);
+
+                vector_t vector_data;
+
+                // pack vector from buffer
+                static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
+                    constexpr auto scalar_id =
+                        typename uniform_sequence_gen<nDim, 0>::type{}.Modify(dst_vector_access_dim,
+                                                                              i);
+
+                    constexpr index_t buffer_offset =
+                        buffer_desc.GetOffsetFromMultiIndex(dst_data_id + scalar_id);
+
+                    reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
+                });
+
+                const index_t dst_offset =
+                    DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_id);
+
+                // store vector into dst
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = vector_data;
+            });
+        }
     }

 private:
-    Array mSrcSliceOrigin;
-    Array mDstSliceOrigin;
+    Array<index_t, nDim> mSrcSliceOrigin;
+    Array<index_t, nDim> mDstSliceOrigin;
 };
 #endif
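The two halves of Run() above share one pattern: move SrcDataPerAccess (resp. DstDataPerAccess) scalars at a time through a vector type, with a scalar register buffer in between. A stripped-down sketch of the load half in plain C++ (simplified to a contiguous 1-D span; "Vector4" stands in for vector_type<float, 4>::MemoryType, and real offsets come from the tensor descriptors):

    struct Vector4 { float x, y, z, w; }; // stand-in for a 4-wide vector type

    void load_and_unpack(const float* p_src, float* p_buffer, int n_vectors)
    {
        for(int v = 0; v < n_vectors; ++v)
        {
            // one vectorized load from src (p_src assumed suitably aligned)
            const Vector4 vector_data = *reinterpret_cast<const Vector4*>(p_src + 4 * v);

            // unpack the vector into the scalar register buffer
            const float* s = reinterpret_cast<const float*>(&vector_data);
            for(int i = 0; i < 4; ++i)
                p_buffer[4 * v + i] = s[i];
        }
    }

The store half is the mirror image: pack DstDataPerAccess scalars from the buffer into a vector, then issue one vectorized store. This is why the constructor can demand stride == 1 only along the vector access dimension, leaving every other dimension free to be strided.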
diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp
index c49341b666..289b9d9b3f 100644
--- a/composable_kernel/include/utility/functional2.hpp
+++ b/composable_kernel/include/utility/functional2.hpp
@@ -23,14 +23,16 @@ struct static_for_impl<Sequence<Is...>>
 template <index_t NBegin, index_t NEnd, index_t Increment>
 struct static_for
 {
+    __host__ __device__ constexpr static_for()
+    {
+        static_assert(NBegin <= NEnd, "wrong! should have NBegin <= NEnd");
+        static_assert((NEnd - NBegin) % Increment == 0,
+                      "wrong! should satisfy (NEnd - NBegin) % Increment == 0");
+    }
+
     template <class F>
     __host__ __device__ constexpr void operator()(F f) const
     {
-        static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
-
-        static_assert((NEnd - NBegin) % Increment == 0,
-                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
-
         static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
     }
 };
diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
index 67395b978d..65f03dc1bb 100644
--- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
@@ -94,6 +94,41 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
     constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;

+#elif 1
+    // each thread holds 64 data
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t BPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t EPerBlock = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockCopySubLengths_E_N1_B_N2      = Sequence<1, 1, 1, 4>;
+    using InBlockCopyClusterLengths_E_N1_B_N2  = Sequence<8, 2, 16, 1>;
+    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
+
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
+
+    using WeiBlockCopySubLengths_E_K            = Sequence<2, 2>;
+    using WeiBlockCopyClusterLengths_E_K        = Sequence<4, 64>;
+    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
+
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 2;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
 #elif 0
     // each thread hold 32 data
     constexpr index_t BlockSize = 256;
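A quick consistency check on the new `#elif 1` tuning block (a sketch; the relations are implied by the parameter names rather than spelled out in the patch): every copy's thread cluster must use exactly BlockSize threads, and sub-lengths times cluster lengths must tile the per-block slice.

    // in-block copy: 8 * 2 * 16 * 1 threads, tile [8, 2, 16, 4] = [EPerBlock, N1, BPerBlock, N2]
    static_assert(8 * 2 * 16 * 1 == 256, "InBlockCopyClusterLengths must equal BlockSize");
    // weight copy: SubLengths<2, 2> * ClusterLengths<4, 64> == [EPerBlock, KPerBlock]
    static_assert(4 * 64 == 256, "WeiBlockCopyClusterLengths must equal BlockSize");
    static_assert(2 * 4 == 8 && 2 * 64 == 128,
                  "WeiBlockCopySubLengths * WeiBlockCopyClusterLengths == [8, 128]");

Note also that WeiBlockCopyDstDataPerWrite_K = 2 matches the K sub-length of 2, and InBlockCopyDstDataPerWrite_N2 = 4 matches the N2 sub-length of 4, so the vectorized writes divide evenly, as the new ThreadwiseGenericTensorSliceCopy_v1 constructor asserts.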
"host_conv.hpp" #include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" -#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" -#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" -#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" +//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" +//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" struct GeneratorTensor_1 {